diff --git a/laboratory/experiment.py b/laboratory/experiment.py index a76d642..2cb7d41 100644 --- a/laboratory/experiment.py +++ b/laboratory/experiment.py @@ -108,13 +108,15 @@ def candidate(self, cand_func, args=None, kwargs=None, name='Candidate', context 'context': context or {}, }) - def conduct(self, randomize=True): + def conduct(self, randomize=True, candidates_first=False): ''' Run control & candidate functions and return the control's return value. ``control()`` must be called first. :param bool randomize: controls whether we shuffle the order of execution between control and candidate + :param bool candidates_first: whether to run the candidates before the + control :raise LaboratoryException: when no control case has been set :return: Control function's return value ''' @@ -135,13 +137,18 @@ def get_func_executor(obs_def, is_control): """A lightweight wrapper around a tested function in order to retrieve state""" return lambda *a, **kw: (self._run_tested_func(raise_on_exception=is_control, **obs_def), is_control) - funcs = [ - get_func_executor(self._control, is_control=True), - ] + [get_func_executor(cand, is_control=False,) for cand in self._candidates] + control_func = get_func_executor(self._control, is_control=True) + funcs = [get_func_executor(cand, is_control=False,) for cand in self._candidates] if randomize: random.shuffle(funcs) + # Insert the control func at a random index if randomize and candidates are not run first + control_index = random.randint(0, len(funcs)) if randomize else 0 + if candidates_first: + control_index = len(funcs) + funcs.insert(control_index, control_func) + control = None candidates = [] diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 6663980..dbe665c 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -173,3 +173,29 @@ def control_func(): control_indexes = [run_experiment() for i in range(5)] assert set(control_indexes) == set([0]) + + +@pytest.mark.parametrize('randomize', [True, False]) +def test_candidates_first_executes_control_last(randomize): + num_candidates = 100 + + def run_experiment(): + exp = laboratory.Experiment() + + counter = {'index': 0} + def increment_counter(): + counter['index'] += 1 + + def control_func(): + return counter['index'] + + cand_func = mock.Mock(side_effect=increment_counter) + + exp.control(control_func) + for _ in range(num_candidates): + exp.candidate(cand_func) + + return exp.conduct(randomize=randomize, candidates_first=True) + + control_indexes = [run_experiment() for i in range(5)] + assert set(control_indexes) == set([num_candidates])