diff --git a/nummi/main/models.py b/nummi/main/models.py index 1884869..73e2fde 100644 --- a/nummi/main/models.py +++ b/nummi/main/models.py @@ -12,7 +12,10 @@ from django.utils.translation import gettext as _ class UserModel(models.Model): user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, verbose_name=_("User"), editable=False + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + verbose_name=_("User"), + editable=False, ) class Meta: @@ -330,6 +333,10 @@ class Snapshot(AccountModel): class NummiForm(ModelForm): template_name = "main/form/base.html" + def __init__(self, *args, **kwargs): + kwargs.pop("user", None) + super().__init__(*args, **kwargs) + class AccountForm(NummiForm): class Meta: @@ -348,6 +355,12 @@ class TransactionForm(NummiForm): model = Transaction fields = "__all__" + def __init__(self, *args, **kwargs): + _user = kwargs.pop("user") + super().__init__(*args, **kwargs) + self.fields["category"].queryset = Category.objects.filter(user=_user) + self.fields["account"].queryset = Account.objects.filter(user=_user) + class InvoiceForm(NummiForm): prefix = "invoice" @@ -361,3 +374,8 @@ class SnapshotForm(NummiForm): class Meta: model = Snapshot fields = "__all__" + + def __init__(self, *args, **kwargs): + _user = kwargs.pop("user") + super().__init__(*args, **kwargs) + self.fields["account"].queryset = Account.objects.filter(user=_user) diff --git a/nummi/main/views.py b/nummi/main/views.py index 9347e4e..36da9d3 100644 --- a/nummi/main/views.py +++ b/nummi/main/views.py @@ -48,8 +48,11 @@ class UserMixin(LoginRequiredMixin): def get_queryset(self, **kwargs): return super().get_queryset().filter(user=self.request.user) + def get_form_kwargs(self): + return super().get_form_kwargs() | {"user": self.request.user} -class UserCreateView(LoginRequiredMixin, CreateView): + +class UserCreateView(UserMixin, CreateView): def form_valid(self, form): form.instance.user = self.request.user return super().form_valid(form)