diff --git a/formtools/wizard/forms.py b/formtools/wizard/forms.py index 11cf422..31f601b 100644 --- a/formtools/wizard/forms.py +++ b/formtools/wizard/forms.py @@ -1,9 +1,21 @@ from django import forms +from django.core.exceptions import ValidationError class ManagementForm(forms.Form): """ ``ManagementForm`` is used to keep track of the current wizard step. """ + template_name = "django/forms/p.html" # Remove when Django 5.0 is minimal version. current_step = forms.CharField(widget=forms.HiddenInput) + + def __init__(self, steps, **kwargs): + self.steps = steps + super().__init__(**kwargs) + + def clean_current_step(self): + value = self.cleaned_data["current_step"] + if value not in self.steps: + raise ValidationError("Invalid step name.") + return value diff --git a/formtools/wizard/views.py b/formtools/wizard/views.py index aad0bdf..9ceb216 100644 --- a/formtools/wizard/views.py +++ b/formtools/wizard/views.py @@ -283,7 +283,7 @@ def post(self, *args, **kwargs): return self.render_goto_step(wizard_goto_step) # Check if form was refreshed - management_form = ManagementForm(self.request.POST, prefix=self.prefix) + management_form = ManagementForm(self.request.POST, steps=self.steps.all, prefix=self.prefix) if not management_form.is_valid(): raise SuspiciousOperation(_('ManagementForm data is missing or has been tampered.')) @@ -582,7 +582,7 @@ def get_context_data(self, form, **kwargs): context['wizard'] = { 'form': form, 'steps': self.steps, - 'management_form': ManagementForm(prefix=self.prefix, initial={ + 'management_form': ManagementForm(prefix=self.prefix, steps=self.steps.all, initial={ 'current_step': self.steps.current, }), } diff --git a/tests/wizard/wizardtests/tests.py b/tests/wizard/wizardtests/tests.py index 96a6c06..af13376 100644 --- a/tests/wizard/wizardtests/tests.py +++ b/tests/wizard/wizardtests/tests.py @@ -73,6 +73,18 @@ def test_form_post_mgmt_data_missing(self): # view should return HTTP 400 Bad Request self.assertEqual(response.status_code, 400) + def test_invalid_current_step_data(self): + wizard_step_data = self.wizard_step_data[0].copy() + + # Replace the current step with invalid data + for key in list(wizard_step_data.keys()): + if "current_step" in key: + wizard_step_data[key] = "not-a-valid-step" + + response = self.client.post(self.wizard_url, wizard_step_data) + # view should return HTTP 400 Bad Request + self.assertEqual(response.status_code, 400) + def test_form_post_success(self): response = self.client.post(self.wizard_url, self.wizard_step_data[0]) wizard = response.context['wizard']