diff --git a/formtools/wizard/views.py b/formtools/wizard/views.py index 07e0a6b..1f7b373 100644 --- a/formtools/wizard/views.py +++ b/formtools/wizard/views.py @@ -165,8 +165,6 @@ def get_initkwargs(cls, form_list=None, initial_dict=None, instance_dict=None, computed_form_list = OrderedDict() - assert len(form_list) > 0, 'at least one form is needed' - # walk through the passed form list for i, form in enumerate(form_list): if isinstance(form, (list, tuple)): @@ -401,7 +399,18 @@ def get_form_kwargs(self, step=None): """ return {} - def get_form(self, step=None, data=None, files=None): + def get_form_class(self, step): + """ + Returns the form class for step. + + If self.form_list is not empty then it is assumed the wizard has been + implemented according to the original form list generation strategy and the form + class is taken from there. If self.form_list is empty, however, then get the + form class from the dynamically generated list provided by get_form_list(). + """ + return self.form_list[step] if self.form_list else self.get_form_list()[step] + + def get_form(self, step=None, data=None, files=None, form_cls=None): """ Constructs the form for a given `step`. If no `step` is defined, the current step will be determined automatically. @@ -409,10 +418,13 @@ def get_form(self, step=None, data=None, files=None): The form will be initialized using the `data` argument to prefill the new form. If needed, instance or queryset (for `ModelForm` or `ModelFormSet`) will be added too. + + If form_cls is provided, this class will be instantiated rather than trying to + retrieve the class from the form list. """ if step is None: step = self.steps.current - form_class = self.get_form_list()[step] + form_class = form_cls or self.get_form_class(step) # prepare the kwargs for the form instance. kwargs = self.get_form_kwargs(step) kwargs.update({ @@ -490,21 +502,27 @@ def get_all_cleaned_data(self): cleaned_data.update(form_obj.cleaned_data) return cleaned_data - def get_cleaned_data_for_step(self, step): + def get_cleaned_data_for_step(self, step, form_cls=None): """ Returns the cleaned data for a given `step`. Before returning the cleaned data, the stored values are revalidated through the form. If the data doesn't validate, None will be returned. + + A form_cls can be provided to avoid having to query the class by calling + get_form_list(). This is useful when overriding get_form_list() to create a + dynamic form list but data from other steps is required. """ - if step in self.form_list: - form_obj = self.get_form( - step=step, - data=self.storage.get_step_data(step), - files=self.storage.get_step_files(step), - ) - if form_obj.is_valid(): - return form_obj.cleaned_data - return None + if self.form_list and step not in self.form_list: + return None + form_obj = self.get_form( + step=step, + data=self.storage.get_step_data(step), + files=self.storage.get_step_files(step), + form_cls=form_cls, + ) + if not form_obj.is_valid(): + return None + return form_obj.cleaned_data def get_next_step(self, step=None): """ diff --git a/tests/wizard/test_forms.py b/tests/wizard/test_forms.py index 9bee6cf..c3fbb9e 100644 --- a/tests/wizard/test_forms.py +++ b/tests/wizard/test_forms.py @@ -1,4 +1,5 @@ import sys +from collections import OrderedDict from importlib import import_module from django import forms, http @@ -93,11 +94,11 @@ def done(self, form_list, **kwargs): class TestWizardWithCustomGetFormList(TestWizard): - - form_list = [Step1] - def get_form_list(self): - return {'start': Step1, 'step2': Step2} + form_list = OrderedDict([('start', Step1), ('step2', Step2)]) + self.get_cleaned_data_for_step("step2", form_cls=Step2) + form_list["step3"] = Step3 + return form_list class FormTests(TestCase): @@ -159,21 +160,41 @@ def test_form_condition(self): response, instance = testform(request) self.assertEqual(instance.get_next_step(), 'step2') - def test_form_condition_avoid_recursion(self): + def test_form_condition_can_check_prior_step_data(self): + def step_check(wizard): + wizard.get_cleaned_data_for_step('start') + return False + + testform = TestWizard.as_view( + [('start', Step1), ('step2', Step2), ('step3', Step3)], + condition_dict={'step2': step_check} + ) + request = get_request() + old_limit = sys.getrecursionlimit() + sys.setrecursionlimit(80) + try: + response, instance = testform(request) + self.assertEqual(instance.get_next_step(), 'step3') + except RecursionError: + self.fail("RecursionError happened during wizard test.") + finally: + sys.setrecursionlimit(old_limit) + + def test_form_condition_future_can_check_future_step_data(self): def subsequent_step_check(wizard): data = wizard.get_cleaned_data_for_step('step3') or {} return data.get('foo') testform = TestWizard.as_view( [('start', Step1), ('step2', Step2), ('step3', Step3)], - condition_dict={'step3': subsequent_step_check} + condition_dict={'step2': subsequent_step_check} ) request = get_request() old_limit = sys.getrecursionlimit() sys.setrecursionlimit(80) try: response, instance = testform(request) - self.assertEqual(instance.get_next_step(), 'step2') + self.assertEqual(instance.get_next_step(), 'step3') except RecursionError: self.fail("RecursionError happened during wizard test.") finally: @@ -298,11 +319,18 @@ def test_get_form_list_default(self): def test_get_form_list_custom(self): request = get_request() - testform = TestWizardWithCustomGetFormList.as_view([('start', Step1)]) + testform = TestWizardWithCustomGetFormList.as_view() response, instance = testform(request) - form_list = instance.get_form_list() - self.assertEqual(form_list, {'start': Step1, 'step2': Step2}) + old_limit = sys.getrecursionlimit() + sys.setrecursionlimit(80) + try: + form_list = instance.get_form_list() + except RecursionError: + self.fail("RecursionError happened during wizard test.") + finally: + sys.setrecursionlimit(old_limit) + self.assertEqual(form_list, {'start': Step1, 'step2': Step2, 'step3': Step3}) self.assertIsInstance(instance.get_form('step2'), Step2)