diff --git a/sdks/python/ag_ui/core/events.py b/sdks/python/ag_ui/core/events.py index 94fb63c75..27432153e 100644 --- a/sdks/python/ag_ui/core/events.py +++ b/sdks/python/ag_ui/core/events.py @@ -279,11 +279,16 @@ class StepFinishedEvent(BaseEvent): TextMessageContentEvent, TextMessageEndEvent, TextMessageChunkEvent, + ThinkingTextMessageStartEvent, + ThinkingTextMessageContentEvent, + ThinkingTextMessageEndEvent, ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallChunkEvent, ToolCallResultEvent, + ThinkingStartEvent, + ThinkingEndEvent, StateSnapshotEvent, StateDeltaEvent, MessagesSnapshotEvent, diff --git a/sdks/python/tests/test_events.py b/sdks/python/tests/test_events.py index 7245d88aa..68fb483fb 100644 --- a/sdks/python/tests/test_events.py +++ b/sdks/python/tests/test_events.py @@ -1,8 +1,10 @@ import unittest import json +import typing from datetime import datetime from pydantic import ValidationError, TypeAdapter +from ag_ui.core import events as events_module from ag_ui.core.types import Message, UserMessage, AssistantMessage, FunctionCall, ToolCall from ag_ui.core.events import ( EventType, @@ -596,6 +598,31 @@ def test_event_with_unicode_and_special_chars(self): # Verify Unicode and special characters are preserved self.assertEqual(deserialized.delta, text) + def test_all_event_subclasses_in_event_union(self): + """Ensure all BaseEvent subclasses are included in the Event union type""" + + # Get all classes defined in the events module that are subclasses of BaseEvent + event_subclasses = set() + for name in dir(events_module): + obj = getattr(events_module, name) + if ( + isinstance(obj, type) + and issubclass(obj, BaseEvent) + and obj is not BaseEvent + ): + event_subclasses.add(obj) + + # Get all types in the Event union + union_types = set(typing.get_args(typing.get_args(Event)[0])) + + # Check that all event subclasses are in the union + missing_from_union = event_subclasses - union_types + self.assertEqual( + missing_from_union, + set(), + f"The following event types are missing from the Event union: {missing_from_union}" + ) + if __name__ == "__main__": unittest.main()