diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fcb530a0d..231cfa56a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -345,6 +345,7 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..7840bfcef 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -425,6 +425,7 @@ async def stream_messages( *, tool_choice: Optional[Any] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, + invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -437,6 +438,7 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -453,6 +455,7 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + invocation_state=invocation_state, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..6b7dd78d7 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -73,6 +73,7 @@ def stream( *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -89,6 +90,7 @@ def stream( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks for advanced features like caching. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..351eadc84 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -36,7 +36,11 @@ @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) + # Skip deep copy of invocation_state which contains non-serializable objects (agent, spans, etc.) + copied_kwargs = { + key: value if key == "invocation_state" else copy.deepcopy(value) for key, value in kwargs.items() + } + result = mock.mock_stream(*copy.deepcopy(args), **copied_kwargs) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -325,6 +329,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -363,6 +368,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), ], ) @@ -484,6 +490,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -629,6 +636,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 6b23bd592..639e60ea0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -383,6 +383,7 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c6e44b78a..b2cc152cb 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1117,6 +1117,7 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) @@ -1150,6 +1151,7 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1183,6 +1185,7 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1214,6 +1217,7 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], + invocation_state=None, ) @@ -1245,6 +1249,7 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, + invocation_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4645e1724..4c4082c00 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,6 +66,7 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) # Verify we get the expected events @@ -131,6 +132,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], + invocation_state=None, ) assert len(tru_events) > 0