Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/test_multiturn_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,43 @@ async def test_state_initialization(self, mock_multiturn_env):
assert "responses" in state
assert isinstance(state["responses"], list)

@pytest.mark.asyncio
async def test_state_initialization_allows_prompt_rewrite(self, mock_multiturn_env):
    """Test that a `setup_state` override can rewrite the prompt.

    The environment's `setup_state` hook is replaced with a stub that edits
    the first prompt message. After the rollout, the returned state must
    hold the rewritten prompt while the remaining fields keep the values
    they were initialized with.
    """
    mock_multiturn_env.client.add_chat_response(
        messages=[{"role": "user", "content": "Test state"}], response="Quick DONE"
    )

    prompt = [{"role": "user", "content": "Test state"}]
    mutated_prompt = [{"role": "user", "content": "Mutated test state"}]

    def mutate_prompt(state, **kwargs):
        # Accept **kwargs because rollout forwards its extra keyword
        # arguments to setup_state; a state-only signature would raise
        # TypeError as soon as any were passed.
        state["prompt"][0]["content"] = "Mutated test state"
        return state

    # Assigned on the instance (not the class), so the stub is invoked
    # without an implicit `self` argument.
    mock_multiturn_env.setup_state = mutate_prompt

    completion, state = await mock_multiturn_env.rollout(
        client=mock_multiturn_env.client,
        model="test-model",
        prompt=prompt,
        answer="test_answer",
        task="test_task",
        info={"extra": "data"},
    )

    # Check prompt was mutated
    assert state["prompt"] == mutated_prompt

    # Check rest of state fields are initialized
    # state["completion"] is initialized to [] but not updated during rollout
    assert state["completion"] == []
    assert state["answer"] == "test_answer"
    assert state["task"] == "test_task"
    assert state["info"] == {"extra": "data"}
    assert "responses" in state
    assert isinstance(state["responses"], list)

@pytest.mark.asyncio
async def test_immediate_completion(self, mock_multiturn_env):
"""Test completion detection on first turn."""
Expand Down
5 changes: 5 additions & 0 deletions verifiers/envs/multiturn_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,16 @@ async def rollout(
}
start_time = time.time()
state = await maybe_await(self.setup_state, state, **kwargs)

if self.message_type == "chat":
assert isinstance(prompt, list)
if isinstance(state["prompt"], list):
prompt = state["prompt"]
completion = []
else:
assert isinstance(prompt, str)
if isinstance(state["prompt"], str):
prompt = state["prompt"]
completion = ""
state["responses_start_idx"] = []
rollout = list(prompt) if not isinstance(prompt, str) else prompt
Expand Down
Loading