From 230e6d831cd4665376b2e9282576663b3d0f9bbd Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 5 Jan 2026 03:09:10 +0000 Subject: [PATCH 1/6] feat: pass invocation_state to model providers Enables custom model providers to access invocation_state from agent calls via kwargs. This supports use cases like custom request metadata, tracing context, and provider-specific configuration. Changes: - Add invocation_state parameter to stream_messages() - Pass invocation_state through event_loop to model.stream() - Update BedrockModel to extract and forward kwargs Custom providers can access via: kwargs.get('invocation_state') --- src/strands/event_loop/event_loop.py | 1 + src/strands/event_loop/streaming.py | 4 ++++ src/strands/models/bedrock.py | 10 ++++++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fcb530a0d..231cfa56a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -345,6 +345,7 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..df8b2dbe3 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -425,6 +425,7 @@ async def stream_messages( *, tool_choice: Optional[Any] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, + invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -437,6 +438,7 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. + invocation_state: Optional invocation state to pass to the model provider. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -453,6 +455,8 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + invocation_state=invocation_state, + **kwargs, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 08d8f400c..0e04cf312 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -193,6 +193,7 @@ def _format_request( tool_specs: Optional[list[ToolSpec]] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, tool_choice: ToolChoice | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -202,6 +203,7 @@ def _format_request( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments Returns: A Bedrock converse stream request. @@ -625,7 +627,9 @@ def callback(event: Optional[StreamEvent] = None) -> None: if system_prompt and system_prompt_content is None: system_prompt_content = [{"text": system_prompt}] - thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice) + thread = asyncio.to_thread( + self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice, **kwargs + ) task = asyncio.create_task(thread) while True: @@ -644,6 +648,7 @@ def _stream( tool_specs: Optional[list[ToolSpec]] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, tool_choice: ToolChoice | None = None, + **kwargs: Any, ) -> None: """Stream conversation with the Bedrock model. @@ -656,6 +661,7 @@ def _stream( tool_specs: List of tool specifications to make available to the model. system_prompt_content: System prompt content blocks to provide context to the model. tool_choice: Selection strategy for tool invocation. + **kwargs: Additional keyword arguments Raises: ContextWindowOverflowException: If the input exceeds the model's context window. @@ -663,7 +669,7 @@ def _stream( """ try: logger.debug("formatting request") - request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) + request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice, **kwargs) logger.debug("request=<%s>", request) logger.debug("invoking model") From 4cd7fdf985a5084753a52601a7f8ace1577ae803 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Tue, 6 Jan 2026 09:52:36 +0000 Subject: [PATCH 2/6] test: fix invocation_state deep copy in test fixtures --- tests/strands/agent/test_agent.py | 13 ++++++++++++- tests/strands/event_loop/test_event_loop.py | 1 + tests/strands/event_loop/test_streaming.py | 5 +++++ .../event_loop/test_streaming_structured_output.py | 2 ++ 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..45ae74103 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -36,7 +36,14 @@ @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) + # Deep copy args and kwargs, but skip invocation_state which may contain non-serializable objects + copied_args = copy.deepcopy(args) + copied_kwargs = { + key: value if key == 'invocation_state' else copy.deepcopy(value) + for key, value in kwargs.items() + } + + result = mock.mock_stream(*copied_args, **copied_kwargs) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -325,6 +332,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -363,6 +371,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), ], ) @@ -484,6 +493,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -629,6 +639,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 6b23bd592..639e60ea0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -383,6 +383,7 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c6e44b78a..b2cc152cb 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1117,6 +1117,7 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) @@ -1150,6 +1151,7 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1183,6 +1185,7 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1214,6 +1217,7 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], + invocation_state=None, ) @@ -1245,6 +1249,7 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, + invocation_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4645e1724..4c4082c00 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,6 +66,7 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) # Verify we get the expected events @@ -131,6 +132,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], + invocation_state=None, ) assert len(tru_events) > 0 From 0845b8841fee4b27a2e73d56979cc833f6cb8d6a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 12 Jan 2026 10:49:46 -0500 Subject: [PATCH 3/6] Revert "test: fix invocation_state deep copy in test fixtures" This reverts commit 4cd7fdf985a5084753a52601a7f8ace1577ae803. --- tests/strands/agent/test_agent.py | 13 +------------ tests/strands/event_loop/test_event_loop.py | 1 - tests/strands/event_loop/test_streaming.py | 5 ----- .../event_loop/test_streaming_structured_output.py | 2 -- 4 files changed, 1 insertion(+), 20 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 45ae74103..f133400a8 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -36,14 +36,7 @@ @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - # Deep copy args and kwargs, but skip invocation_state which may contain non-serializable objects - copied_args = copy.deepcopy(args) - copied_kwargs = { - key: value if key == 'invocation_state' else copy.deepcopy(value) - for key, value in kwargs.items() - } - - result = mock.mock_stream(*copied_args, **copied_kwargs) + result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -332,7 +325,6 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], - invocation_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -371,7 +363,6 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], - invocation_state=unittest.mock.ANY, ), ], ) @@ -493,7 +484,6 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, - invocation_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -639,7 +629,6 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, - invocation_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 639e60ea0..6b23bd592 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -383,7 +383,6 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, - invocation_state=unittest.mock.ANY, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index b2cc152cb..c6e44b78a 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1117,7 +1117,6 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], - invocation_state=None, ) @@ -1151,7 +1150,6 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, - invocation_state=None, ) @@ -1185,7 +1183,6 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, - invocation_state=None, ) @@ -1217,7 +1214,6 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], - invocation_state=None, ) @@ -1249,7 +1245,6 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, - invocation_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4c4082c00..4645e1724 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,7 +66,6 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], - invocation_state=None, ) # Verify we get the expected events @@ -132,7 +131,6 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], - invocation_state=None, ) assert len(tru_events) > 0 From 3a4e6ac554e80b1d110ca5e946a0a38a1219ea2a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 12 Jan 2026 10:49:55 -0500 Subject: [PATCH 4/6] Revert "feat: pass invocation_state to model providers" This reverts commit 230e6d831cd4665376b2e9282576663b3d0f9bbd. --- src/strands/event_loop/event_loop.py | 1 - src/strands/event_loop/streaming.py | 4 ---- src/strands/models/bedrock.py | 10 ++-------- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 231cfa56a..fcb530a0d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -345,7 +345,6 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, - invocation_state=invocation_state, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index df8b2dbe3..804f90a1d 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -425,7 +425,6 @@ async def stream_messages( *, tool_choice: Optional[Any] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, - invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -438,7 +437,6 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. - invocation_state: Optional invocation state to pass to the model provider. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -455,8 +453,6 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, - invocation_state=invocation_state, - **kwargs, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 0e04cf312..08d8f400c 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -193,7 +193,6 @@ def _format_request( tool_specs: Optional[list[ToolSpec]] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, tool_choice: ToolChoice | None = None, - **kwargs: Any, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -203,7 +202,6 @@ def _format_request( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks to provide context to the model. - **kwargs: Additional keyword arguments Returns: A Bedrock converse stream request. @@ -627,9 +625,7 @@ def callback(event: Optional[StreamEvent] = None) -> None: if system_prompt and system_prompt_content is None: system_prompt_content = [{"text": system_prompt}] - thread = asyncio.to_thread( - self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice, **kwargs - ) + thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice) task = asyncio.create_task(thread) while True: @@ -648,7 +644,6 @@ def _stream( tool_specs: Optional[list[ToolSpec]] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, tool_choice: ToolChoice | None = None, - **kwargs: Any, ) -> None: """Stream conversation with the Bedrock model. @@ -661,7 +656,6 @@ def _stream( tool_specs: List of tool specifications to make available to the model. system_prompt_content: System prompt content blocks to provide context to the model. tool_choice: Selection strategy for tool invocation. - **kwargs: Additional keyword arguments Raises: ContextWindowOverflowException: If the input exceeds the model's context window. @@ -669,7 +663,7 @@ def _stream( """ try: logger.debug("formatting request") - request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice, **kwargs) + request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") From c854d25429daf171329cb5a1ab4ebaaf430a94fa Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 12 Jan 2026 11:10:41 -0500 Subject: [PATCH 5/6] feat(models): pass invocation_state to model providers --- src/strands/event_loop/streaming.py | 3 +++ src/strands/models/model.py | 2 ++ tests/strands/agent/test_agent.py | 4 ++++ tests/strands/event_loop/test_event_loop.py | 1 + tests/strands/event_loop/test_streaming.py | 5 +++++ tests/strands/event_loop/test_streaming_structured_output.py | 2 ++ 6 files changed, 17 insertions(+) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..7840bfcef 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -425,6 +425,7 @@ async def stream_messages( *, tool_choice: Optional[Any] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, + invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -437,6 +438,7 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -453,6 +455,7 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + invocation_state=invocation_state, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..6b7dd78d7 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -73,6 +73,7 @@ def stream( *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -89,6 +90,7 @@ def stream( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks for advanced features like caching. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..d2a373b45 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -325,6 +325,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=None, ), unittest.mock.call( [ @@ -363,6 +364,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=None, ), ], ) @@ -484,6 +486,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=None, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -629,6 +632,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=None, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 6b23bd592..05bf75eeb 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -383,6 +383,7 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=None, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c6e44b78a..b2cc152cb 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1117,6 +1117,7 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) @@ -1150,6 +1151,7 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1183,6 +1185,7 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1214,6 +1217,7 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], + invocation_state=None, ) @@ -1245,6 +1249,7 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, + invocation_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4645e1724..4c4082c00 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,6 +66,7 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) # Verify we get the expected events @@ -131,6 +132,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], + invocation_state=None, ) assert len(tru_events) > 0 From 119203030792be0b0ce63416dc3d5b9ed9598e42 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 12 Jan 2026 11:16:22 -0500 Subject: [PATCH 6/6] feat(models): pass invocation_state to streaming from event_loop --- src/strands/event_loop/event_loop.py | 1 + tests/strands/agent/test_agent.py | 14 +++++++++----- tests/strands/event_loop/test_event_loop.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fcb530a0d..231cfa56a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -345,6 +345,7 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, ): yield event diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d2a373b45..351eadc84 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -36,7 +36,11 @@ @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) + # Skip deep copy of invocation_state which contains non-serializable objects (agent, spans, etc.) + copied_kwargs = { + key: value if key == "invocation_state" else copy.deepcopy(value) for key, value in kwargs.items() + } + result = mock.mock_stream(*copy.deepcopy(args), **copied_kwargs) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -325,7 +329,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], - invocation_state=None, + invocation_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -364,7 +368,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], - invocation_state=None, + invocation_state=unittest.mock.ANY, ), ], ) @@ -486,7 +490,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, - invocation_state=None, + invocation_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -632,7 +636,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, - invocation_state=None, + invocation_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 05bf75eeb..639e60ea0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -383,7 +383,7 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, - invocation_state=None, + invocation_state=unittest.mock.ANY, )