diff --git a/tests/test_multiturn_env.py b/tests/test_multiturn_env.py index c2acbd00f..7ab009eaa 100644 --- a/tests/test_multiturn_env.py +++ b/tests/test_multiturn_env.py @@ -181,6 +181,43 @@ async def test_state_initialization(self, mock_multiturn_env): assert "responses" in state assert isinstance(state["responses"], list) + @pytest.mark.asyncio + async def test_state_initialization_allows_prompt_rewrite(self, mock_multiturn_env): + """Test that `set_state` allows for prompt rewriting""" + mock_multiturn_env.client.add_chat_response( + messages=[{"role": "user", "content": "Test state"}], response="Quick DONE" + ) + + prompt = [{"role": "user", "content": "Test state"}] + mutated_prompt = [{"role": "user", "content": "Mutated test state"}] + + def mutate_prompt(state): + state["prompt"][0]["content"] = "Mutated test state" + return state + + mock_multiturn_env.setup_state = mutate_prompt + + completion, state = await mock_multiturn_env.rollout( + client=mock_multiturn_env.client, + model="test-model", + prompt=prompt, + answer="test_answer", + task="test_task", + info={"extra": "data"}, + ) + + # Check prompt was mutated + assert state["prompt"] == mutated_prompt + + # Check rest of state fields are initialized + # state["completion"] is initialized to [] but not updated during rollout + assert state["completion"] == [] + assert state["answer"] == "test_answer" + assert state["task"] == "test_task" + assert state["info"] == {"extra": "data"} + assert "responses" in state + assert isinstance(state["responses"], list) + @pytest.mark.asyncio async def test_immediate_completion(self, mock_multiturn_env): """Test completion detection on first turn.""" diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py index 631376854..b54c8289a 100644 --- a/verifiers/envs/multiturn_env.py +++ b/verifiers/envs/multiturn_env.py @@ -82,11 +82,16 @@ async def rollout( } start_time = time.time() state = await maybe_await(self.setup_state, state, **kwargs) + if self.message_type == "chat": assert isinstance(prompt, list) + if isinstance(state["prompt"], list): + prompt = state["prompt"] completion = [] else: assert isinstance(prompt, str) + if isinstance(state["prompt"], str): + prompt = state["prompt"] completion = "" state["responses_start_idx"] = [] rollout = list(prompt) if not isinstance(prompt, str) else prompt