diff --git a/formtools/wizard/views.py b/formtools/wizard/views.py index 596d622..8561217 100644 --- a/formtools/wizard/views.py +++ b/formtools/wizard/views.py @@ -46,7 +46,7 @@ def __repr__(self): @property def all(self): "Returns the names of all steps/forms." - return list(self._wizard.get_form_list()) + return list(self._wizard.form_list) @property def count(self): @@ -201,28 +201,21 @@ def get_prefix(self, request, *args, **kwargs): # TODO: Add some kind of unique id to prefix return normalize_name(self.__class__.__name__) - def get_form_list(self): + def process_condition_dict(self): """ - This method returns a form_list based on the initial form list but - checks if there is a condition method/value in the condition_list. - If an entry exists in the condition list, it will call/read the value - and respect the result. (True means add the form, False means ignore - the form) - - The form_list is always generated on the fly because condition methods - could use data from other (maybe previous forms). + This method prunes `self.form_list` by checking if there is a condition method/value in `condition_list`. + If an entry exists, it will call/read the value and respect the result. If the condition returns False, the + form will be removed from `form_list`. """ - form_list = OrderedDict() - for form_key, form_class in self.form_list.items(): + for form_key in list(self.form_list.keys()): # try to fetch the value from condition list, by default, the form # gets passed to the new list. condition = self.condition_dict.get(form_key, True) if callable(condition): # call the value if needed, passes the current instance. condition = condition(self) - if condition: - form_list[form_key] = form_class - return form_list + if not condition: + del self.form_list[form_key] def dispatch(self, request, *args, **kwargs): """ @@ -241,6 +234,7 @@ def dispatch(self, request, *args, **kwargs): getattr(self, 'file_storage', None), ) self.steps = StepsHelper(self) + self.process_condition_dict() response = super().dispatch(request, *args, **kwargs) # update the response (e.g. adding cookies) @@ -273,7 +267,7 @@ def post(self, *args, **kwargs): # contains a valid step name. If one was found, render the requested # form. (This makes stepping back a lot easier). wizard_goto_step = self.request.POST.get('wizard_goto_step', None) - if wizard_goto_step and wizard_goto_step in self.get_form_list(): + if wizard_goto_step and wizard_goto_step in self.form_list: return self.render_goto_step(wizard_goto_step) # Check if form was refreshed @@ -342,7 +336,7 @@ def render_done(self, form, **kwargs): """ final_forms = OrderedDict() # walk through the form list and try to validate the data again. - for form_key in self.get_form_list(): + for form_key in self.form_list.keys(): form_obj = self.get_form( step=form_key, data=self.storage.get_step_data(form_key), @@ -406,7 +400,7 @@ def get_form(self, step=None, data=None, files=None): """ if step is None: step = self.steps.current - form_class = self.get_form_list()[step] + form_class = self.form_list[step] # prepare the kwargs for the form instance. kwargs = self.get_form_kwargs(step) kwargs.update({ @@ -469,7 +463,7 @@ def get_all_cleaned_data(self): 'formset-' and contain a list of the formset cleaned_data dictionaries. """ cleaned_data = {} - for form_key in self.get_form_list(): + for form_key in self.form_list.keys(): form_obj = self.get_form( step=form_key, data=self.storage.get_step_data(form_key), @@ -510,8 +504,7 @@ def get_next_step(self, step=None): """ if step is None: step = self.steps.current - form_list = self.get_form_list() - keys = list(form_list.keys()) + keys = list(self.form_list.keys()) if step not in keys: return self.steps.first key = keys.index(step) + 1 @@ -529,8 +522,7 @@ def get_prev_step(self, step=None): """ if step is None: step = self.steps.current - form_list = self.get_form_list() - keys = list(form_list.keys()) + keys = list(self.form_list.keys()) if step not in keys: return None key = keys.index(step) - 1 @@ -547,7 +539,7 @@ def get_step_index(self, step=None): """ if step is None: step = self.steps.current - keys = list(self.get_form_list().keys()) + keys = list(self.form_list.keys()) if step in keys: return keys.index(step) return None @@ -678,7 +670,7 @@ def get(self, *args, **kwargs): ) return self.render(form, **kwargs) - elif step_url in self.get_form_list(): + elif step_url in self.form_list: self.storage.current_step = step_url return self.render( self.get_form( @@ -699,7 +691,7 @@ def post(self, *args, **kwargs): is super'd from WizardView. """ wizard_goto_step = self.request.POST.get('wizard_goto_step', None) - if wizard_goto_step and wizard_goto_step in self.get_form_list(): + if wizard_goto_step and wizard_goto_step in self.form_list: return self.render_goto_step(wizard_goto_step) return super().post(*args, **kwargs) diff --git a/tests/wizard/test_forms.py b/tests/wizard/test_forms.py index 5711848..9f3f553 100644 --- a/tests/wizard/test_forms.py +++ b/tests/wizard/test_forms.py @@ -92,11 +92,12 @@ def done(self, form_list, **kwargs): class TestWizardWithCustomGetFormList(TestWizard): + form_list = [('start', Step1)] - form_list = [Step1] - - def get_form_list(self): - return {'start': Step1, 'step2': Step2} + def process_condition_dict(self): + super().process_condition_dict() + # Modify the `form_list` using any criteria (e.g. whether the user is logged in, etc.) or none at all + self.form_list['step2'] = Step2 class FormTests(TestCase): @@ -158,19 +159,6 @@ def test_form_condition(self): response, instance = testform(request) self.assertEqual(instance.get_next_step(), 'step2') - def test_form_condition_unstable(self): - request = get_request() - testform = TestWizard.as_view( - [('start', Step1), ('step2', Step2), ('step3', Step3)], - condition_dict={'step2': True} - ) - response, instance = testform(request) - self.assertEqual(instance.get_step_index('step2'), 1) - self.assertEqual(instance.get_next_step('step2'), 'step3') - instance.condition_dict['step2'] = False - self.assertEqual(instance.get_step_index('step2'), None) - self.assertEqual(instance.get_next_step('step2'), 'start') - def test_form_kwargs(self): request = get_request() testform = TestWizard.as_view([ @@ -265,23 +253,21 @@ def test_form_list_type(self): response, instance = testform(request) self.assertEqual(response.status_code, 200) - def test_get_form_list_default(self): + def test_form_list_default(self): request = get_request() testform = TestWizard.as_view([('start', Step1)]) response, instance = testform(request) - form_list = instance.get_form_list() - self.assertEqual(form_list, {'start': Step1}) + self.assertEqual(instance.form_list, {'start': Step1}) with self.assertRaises(KeyError): instance.get_form('step2') - def test_get_form_list_custom(self): + def test_form_list_custom(self): request = get_request() - testform = TestWizardWithCustomGetFormList.as_view([('start', Step1)]) + testform = TestWizardWithCustomGetFormList.as_view() response, instance = testform(request) - form_list = instance.get_form_list() - self.assertEqual(form_list, {'start': Step1, 'step2': Step2}) + self.assertEqual(instance.form_list, {'start': Step1, 'step2': Step2}) self.assertIsInstance(instance.get_form('step2'), Step2)