From a720f63b4dcebf466b0675bd2f7a2985b753ef40 Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Tue, 29 Jul 2025 14:11:22 -0400
Subject: [PATCH] Realtime: enable a playback tracker

---
 src/agents/realtime/__init__.py            |   4 +
 src/agents/realtime/_default_tracker.py    |  47 +++++++++
 src/agents/realtime/_util.py               |   9 ++
 src/agents/realtime/model.py               |  94 ++++++++++++++++++
 src/agents/realtime/openai_realtime.py     |  82 ++++++++++------
 tests/realtime/test_agent.py               |   2 +
 tests/realtime/test_conversion_helpers.py  |   2 +
 tests/realtime/test_openai_realtime.py     |  83 +++++++++-------
 tests/realtime/test_playback_tracker.py    | 112 ++++++++++++++++++++++
 9 files changed, 376 insertions(+), 59 deletions(-)
 create mode 100644 src/agents/realtime/_default_tracker.py
 create mode 100644 src/agents/realtime/_util.py
 create mode 100644 tests/realtime/test_playback_tracker.py

diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py
index 49c131389..7675c466f 100644
--- a/src/agents/realtime/__init__.py
+++ b/src/agents/realtime/__init__.py
@@ -47,6 +47,8 @@
     RealtimeModel,
     RealtimeModelConfig,
     RealtimeModelListener,
+    RealtimePlaybackState,
+    RealtimePlaybackTracker,
 )
 from .model_events import (
     RealtimeConnectionStatus,
@@ -139,6 +141,8 @@
     "RealtimeModel",
     "RealtimeModelConfig",
     "RealtimeModelListener",
+    "RealtimePlaybackTracker",
+    "RealtimePlaybackState",
     # Model Events
     "RealtimeConnectionStatus",
     "RealtimeModelAudioDoneEvent",
diff --git a/src/agents/realtime/_default_tracker.py b/src/agents/realtime/_default_tracker.py
new file mode 100644
index 000000000..49bc827c2
--- /dev/null
+++ b/src/agents/realtime/_default_tracker.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from datetime import datetime
+
+from ._util import calculate_audio_length_ms
+from .config import RealtimeAudioFormat
+
+
+@dataclass
+class ModelAudioState:
+    initial_received_time: datetime
+    audio_length_ms: float
+
+
+class ModelAudioTracker:
+    def __init__(self) -> None:
+        # (item_id, item_content_index) -> ModelAudioState
+        self._states: dict[tuple[str, int], ModelAudioState] = {}
+        self._last_audio_item: tuple[str, int] | None = None
+
+    def set_audio_format(self, format: RealtimeAudioFormat) -> None:
+        """Called when the model wants to set the audio format."""
+        self._format = format
+
+    def on_audio_delta(self, item_id: str, item_content_index: int, audio_bytes: bytes) -> None:
+        """Called when an audio delta is received from the model."""
+        ms = calculate_audio_length_ms(self._format, audio_bytes)
+        new_key = (item_id, item_content_index)
+
+        self._last_audio_item = new_key
+        if new_key not in self._states:
+            self._states[new_key] = ModelAudioState(datetime.now(), ms)
+        else:
+            self._states[new_key].audio_length_ms += ms
+
+    def on_interrupted(self) -> None:
+        """Called when the audio playback has been interrupted."""
+        self._last_audio_item = None
+
+    def get_state(self, item_id: str, item_content_index: int) -> ModelAudioState | None:
+        """Called when the model wants to get the current playback state."""
+        return self._states.get((item_id, item_content_index))
+
+    def get_last_audio_item(self) -> tuple[str, int] | None:
+        """Called when the model wants to get the last audio item ID and content index."""
+        return self._last_audio_item
diff --git a/src/agents/realtime/_util.py b/src/agents/realtime/_util.py
new file mode 100644
index 000000000..c8926edfb
--- /dev/null
+++ b/src/agents/realtime/_util.py
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+from .config import RealtimeAudioFormat
+
+
+def calculate_audio_length_ms(format: RealtimeAudioFormat | None, audio_bytes: bytes) -> float:
+    if format and format.startswith("g711"):
+        return (len(audio_bytes) / 8000) * 1000
+    return (len(audio_bytes) / 24 / 2) * 1000
diff --git a/src/agents/realtime/model.py b/src/agents/realtime/model.py
index e279ecc95..d7ebe4ffa 100644
--- a/src/agents/realtime/model.py
+++ b/src/agents/realtime/model.py
@@ -6,13 +6,95 @@
 from typing_extensions import NotRequired, TypedDict
 
 from ..util._types import MaybeAwaitable
+from ._util import calculate_audio_length_ms
 from .config import (
+    RealtimeAudioFormat,
     RealtimeSessionModelSettings,
 )
 from .model_events import RealtimeModelEvent
 from .model_inputs import RealtimeModelSendEvent
 
 
+class RealtimePlaybackState(TypedDict):
+    current_item_id: str | None
+    """The item ID of the current item being played."""
+
+    current_item_content_index: int | None
+    """The index of the current item content being played."""
+
+    elapsed_ms: float | None
+    """The number of milliseconds of audio that have been played."""
+
+
+class RealtimePlaybackTracker:
+    """If you have custom playback logic or expect that audio is played with delays or at different
+    speeds, create an instance of RealtimePlaybackTracker and pass it to the session. You are
+    responsible for tracking the audio playback progress and calling `on_play_bytes` or
+    `on_play_ms` when the user has played some audio."""
+
+    def __init__(self) -> None:
+        self._format: RealtimeAudioFormat | None = None
+        # (item_id, item_content_index)
+        self._current_item: tuple[str, int] | None = None
+        self._elapsed_ms: float | None = None
+
+    def on_play_bytes(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
+        """Called by you when you have played some audio.
+
+        Args:
+            item_id: The item ID of the audio being played.
+            item_content_index: The index of the audio content in `item.content`
+            bytes: The audio bytes that have been fully played.
+        """
+        ms = calculate_audio_length_ms(self._format, bytes)
+        self.on_play_ms(item_id, item_content_index, ms)
+
+    def on_play_ms(self, item_id: str, item_content_index: int, ms: float) -> None:
+        """Called by you when you have played some audio.
+
+        Args:
+            item_id: The item ID of the audio being played.
+            item_content_index: The index of the audio content in `item.content`
+            ms: The number of milliseconds of audio that have been played.
+        """
+        if self._current_item != (item_id, item_content_index):
+            self._current_item = (item_id, item_content_index)
+            self._elapsed_ms = ms
+        else:
+            assert self._elapsed_ms is not None
+            self._elapsed_ms += ms
+
+    def on_interrupted(self) -> None:
+        """Called by the model when the audio playback has been interrupted."""
+        self._current_item = None
+        self._elapsed_ms = None
+
+    def set_audio_format(self, format: RealtimeAudioFormat) -> None:
+        """Will be called by the model to set the audio format.
+
+        Args:
+            format: The audio format to use.
+        """
+        self._format = format
+
+    def get_state(self) -> RealtimePlaybackState:
+        """Will be called by the model to get the current playback state."""
+        if self._current_item is None:
+            return {
+                "current_item_id": None,
+                "current_item_content_index": None,
+                "elapsed_ms": None,
+            }
+        assert self._elapsed_ms is not None
+
+        item_id, item_content_index = self._current_item
+        return {
+            "current_item_id": item_id,
+            "current_item_content_index": item_content_index,
+            "elapsed_ms": self._elapsed_ms,
+        }
+
+
 class RealtimeModelListener(abc.ABC):
     """A listener for realtime transport events."""
 
@@ -39,6 +121,18 @@ class RealtimeModelConfig(TypedDict):
     initial_model_settings: NotRequired[RealtimeSessionModelSettings]
     """The initial model settings to use when connecting."""
 
+    playback_tracker: NotRequired[RealtimePlaybackTracker]
+    """The playback tracker to use when tracking audio playback progress. If not set, the model will
+    use a default implementation that assumes audio is played immediately, at realtime speed.
+
+    A playback tracker is useful for interruptions. The model generates audio much faster than
+    realtime playback speed. So if there's an interruption, it's useful for the model to know how
+    much of the audio has been played by the user. In low-latency scenarios, it's fine to assume
+    that audio is played back immediately at realtime speed. But in scenarios like phone calls or
+    other remote interactions, you can set a playback tracker that lets the model know when audio
+    is played to the user.
+    """
+
 
 class RealtimeModel(abc.ABC):
     """Interface for connecting to a realtime model and sending/receiving events."""
diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py
index ab9408dfe..2e950b0a1 100644
--- a/src/agents/realtime/openai_realtime.py
+++ b/src/agents/realtime/openai_realtime.py
@@ -57,6 +57,7 @@
 from websockets.asyncio.client import ClientConnection
 
 from agents.handoffs import Handoff
+from agents.realtime._default_tracker import ModelAudioTracker
 from agents.tool import FunctionTool, Tool
 from agents.util._types import MaybeAwaitable
 
@@ -72,6 +73,8 @@
     RealtimeModel,
     RealtimeModelConfig,
     RealtimeModelListener,
+    RealtimePlaybackState,
+    RealtimePlaybackTracker,
 )
 from .model_events import (
     RealtimeModelAudioDoneEvent,
@@ -133,11 +136,10 @@ def __init__(self) -> None:
         self._websocket_task: asyncio.Task[None] | None = None
         self._listeners: list[RealtimeModelListener] = []
         self._current_item_id: str | None = None
-        self._audio_start_time: datetime | None = None
-        self._audio_length_ms: float = 0.0
+        self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
         self._ongoing_response: bool = False
-        self._current_audio_content_index: int | None = None
         self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
+        self._playback_tracker: RealtimePlaybackTracker | None = None
 
     async def connect(self, options: RealtimeModelConfig) -> None:
         """Establish a connection to the model and keep it alive."""
@@ -146,6 +148,8 @@ async def connect(self, options: RealtimeModelConfig) -> None:
 
         model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {})
 
+        self._playback_tracker = options.get("playback_tracker", RealtimePlaybackTracker())
+
         self.model = model_settings.get("model_name", self.model)
 
         api_key = await get_api_key(options.get("api_key"))
@@ -294,31 +298,62 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
         if event.start_response:
             await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
 
+    def _get_playback_state(self) -> RealtimePlaybackState:
+        if self._playback_tracker:
+            return self._playback_tracker.get_state()
+
+        if last_audio_item_id := self._audio_state_tracker.get_last_audio_item():
+            item_id, item_content_index = last_audio_item_id
+            audio_state = self._audio_state_tracker.get_state(item_id, item_content_index)
+            if audio_state:
+                elapsed_ms = (
+                    datetime.now() - audio_state.initial_received_time
+                ).total_seconds() * 1000
+                return {
+                    "current_item_id": item_id,
+                    "current_item_content_index": item_content_index,
+                    "elapsed_ms": elapsed_ms,
+                }
+
+        return {
+            "current_item_id": None,
+            "current_item_content_index": None,
+            "elapsed_ms": None,
+        }
+
     async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
-        if not self._current_item_id or not self._audio_start_time:
+        playback_state = self._get_playback_state()
+        current_item_id = playback_state.get("current_item_id")
+        current_item_content_index = playback_state.get("current_item_content_index")
+        elapsed_ms = playback_state.get("elapsed_ms")
+        if current_item_id is None or elapsed_ms is None:
+            logger.info(
+                "Skipping interrupt. "
+                f"Item id: {current_item_id}, "
+                f"elapsed ms: {elapsed_ms}, "
+                f"content index: {current_item_content_index}"
+            )
             return
 
-        await self._cancel_response()
-
-        elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
-        if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
+        current_item_content_index = current_item_content_index or 0
+        if elapsed_ms > 0:
             await self._emit_event(
                 RealtimeModelAudioInterruptedEvent(
-                    item_id=self._current_item_id,
-                    content_index=self._current_audio_content_index or 0,
+                    item_id=current_item_id,
+                    content_index=current_item_content_index,
                 )
             )
             converted = _ConversionHelper.convert_interrupt(
-                self._current_item_id,
-                self._current_audio_content_index or 0,
-                int(elapsed_time_ms),
+                current_item_id,
+                current_item_content_index,
+                int(elapsed_ms),
             )
             await self._send_raw_message(converted)
+            await self._cancel_response()
 
-        self._current_item_id = None
-        self._audio_start_time = None
-        self._audio_length_ms = 0.0
-        self._current_audio_content_index = None
+        self._audio_state_tracker.on_interrupted()
+        if self._playback_tracker:
+            self._playback_tracker.on_interrupted()
 
     async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
         """Send a session update to the model."""
@@ -326,15 +361,12 @@ async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> N
 
     async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
         """Handle audio delta events and update audio tracking state."""
-        self._current_audio_content_index = parsed.content_index
         self._current_item_id = parsed.item_id
-        if self._audio_start_time is None:
-            self._audio_start_time = datetime.now()
-            self._audio_length_ms = 0.0
 
         audio_bytes = base64.b64decode(parsed.delta)
-        # Calculate audio length in ms using 24KHz pcm16le
-        self._audio_length_ms += self._calculate_audio_length_ms(audio_bytes)
+
+        self._audio_state_tracker.on_audio_delta(parsed.item_id, parsed.content_index, audio_bytes)
+
         await self._emit_event(
             RealtimeModelAudioEvent(
                 data=audio_bytes,
@@ -344,10 +376,6 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
             )
         )
 
-    def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
-        """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
-        return len(audio_bytes) / 24 / 2
-
     async def _handle_output_item(self, item: ConversationItem) -> None:
         """Handle response output item events (function calls and messages)."""
         if item.type == "function_call" and item.status == "completed":
diff --git a/tests/realtime/test_agent.py b/tests/realtime/test_agent.py
index aae8bc47c..7f1dc3ea3 100644
--- a/tests/realtime/test_agent.py
+++ b/tests/realtime/test_agent.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pytest
 
 from agents import RunContextWrapper
diff --git a/tests/realtime/test_conversion_helpers.py b/tests/realtime/test_conversion_helpers.py
index 859813edd..2d84c8c49 100644
--- a/tests/realtime/test_conversion_helpers.py
+++ b/tests/realtime/test_conversion_helpers.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 from unittest.mock import Mock
 
diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py
index 5cb0eb0fa..704d95d40 100644
--- a/tests/realtime/test_openai_realtime.py
+++ b/tests/realtime/test_openai_realtime.py
@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import Any
 from unittest.mock import AsyncMock, Mock, patch
 
@@ -226,6 +225,9 @@ async def test_handle_audio_delta_event_success(self, model):
         mock_listener = AsyncMock()
         model.add_listener(mock_listener)
 
+        # Set up audio format on the tracker before testing
+        model._audio_state_tracker.set_audio_format("pcm16")
+
         # Valid audio delta event (minimal required fields for OpenAI spec)
         audio_event = {
             "type": "response.audio.delta",
@@ -237,23 +239,22 @@ async def test_handle_audio_delta_event_success(self, model):
             "delta": "dGVzdCBhdWRpbw==",  # base64 encoded "test audio"
         }
 
-        with patch("agents.realtime.openai_realtime.datetime") as mock_datetime:
-            mock_now = datetime(2024, 1, 1, 12, 0, 0)
-            mock_datetime.now.return_value = mock_now
+        await model._handle_ws_event(audio_event)
 
-            await model._handle_ws_event(audio_event)
+        # Should emit audio event to listeners
+        mock_listener.on_event.assert_called_once()
+        emitted_event = mock_listener.on_event.call_args[0][0]
+        assert isinstance(emitted_event, RealtimeModelAudioEvent)
+        assert emitted_event.response_id == "resp_123"
+        assert emitted_event.data == b"test audio"  # decoded from base64
 
-            # Should emit audio event to listeners
-            mock_listener.on_event.assert_called_once()
-            emitted_event = mock_listener.on_event.call_args[0][0]
-            assert isinstance(emitted_event, RealtimeModelAudioEvent)
-            assert emitted_event.response_id == "resp_123"
-            assert emitted_event.data == b"test audio"  # decoded from base64
+        # Should update internal audio tracking state
+        assert model._current_item_id == "item_456"
 
-            # Should update internal audio tracking state
-            assert model._current_item_id == "item_456"
-            assert model._current_audio_content_index == 0
-            assert model._audio_start_time == mock_now
+        # Test that audio state is tracked in the tracker
+        audio_state = model._audio_state_tracker.get_state("item_456", 0)
+        assert audio_state is not None
+        assert audio_state.audio_length_ms > 0  # Should have some audio length
 
     @pytest.mark.asyncio
     async def test_handle_error_event_success(self, model):
@@ -319,6 +320,9 @@ async def test_audio_timing_calculation_accuracy(self, model):
         mock_listener = AsyncMock()
         model.add_listener(mock_listener)
 
+        # Set up audio format on the tracker before testing
+        model._audio_state_tracker.set_audio_format("pcm16")
+
         # Send multiple audio deltas to test cumulative timing
         audio_deltas = [
             {
@@ -344,21 +348,34 @@ async def test_audio_timing_calculation_accuracy(self, model):
         for event in audio_deltas:
             await model._handle_ws_event(event)
 
-        # Should accumulate audio length: 8 bytes / 24 / 2 = ~0.167ms per byte
-        # Total: 8 bytes / 24 / 2 = 0.167ms
-        expected_length = 8 / 24 / 2
-        assert abs(model._audio_length_ms - expected_length) < 0.001
+        # Should accumulate audio length: 8 bytes / 24 / 2 * 1000 = milliseconds
+        # Total: 8 bytes / 24 / 2 * 1000
+        expected_length = (8 / 24 / 2) * 1000
+
+        # Test through the actual audio state tracker
+        audio_state = model._audio_state_tracker.get_state("item_1", 0)
+        assert audio_state is not None
+        assert abs(audio_state.audio_length_ms - expected_length) < 0.001
 
     def test_calculate_audio_length_ms_pure_function(self, model):
         """Test the pure audio length calculation function."""
-        # Test various audio buffer sizes
-        assert model._calculate_audio_length_ms(b"test") == 4 / 24 / 2  # 4 bytes
-        assert model._calculate_audio_length_ms(b"") == 0  # empty
-        assert model._calculate_audio_length_ms(b"a" * 48) == 1.0  # exactly 1ms worth
+        from agents.realtime._util import calculate_audio_length_ms
+
+        # Test various audio buffer sizes for pcm16 format
+        assert calculate_audio_length_ms("pcm16", b"test") == (4 / 24 / 2) * 1000  # 4 bytes
+        assert calculate_audio_length_ms("pcm16", b"") == 0  # empty
+        assert calculate_audio_length_ms("pcm16", b"a" * 48) == 1000.0  # exactly 1000ms worth
+
+        # Test g711 format
+        assert calculate_audio_length_ms("g711_ulaw", b"test") == (4 / 8000) * 1000  # 4 bytes
+        assert calculate_audio_length_ms("g711_alaw", b"a" * 8) == (8 / 8000) * 1000  # 8 bytes
 
     @pytest.mark.asyncio
     async def test_handle_audio_delta_state_management(self, model):
         """Test that _handle_audio_delta properly manages internal state."""
+        # Set up audio format on the tracker before testing
+        model._audio_state_tracker.set_audio_format("pcm16")
+
         # Create mock parsed event
         mock_parsed = Mock()
         mock_parsed.content_index = 5
@@ -366,14 +383,16 @@ async def test_handle_audio_delta_state_management(self, model):
         mock_parsed.delta = "dGVzdA=="  # "test" in base64
         mock_parsed.response_id = "resp_123"
 
-        with patch("agents.realtime.openai_realtime.datetime") as mock_datetime:
-            mock_now = datetime(2024, 1, 1, 12, 0, 0)
-            mock_datetime.now.return_value = mock_now
+        await model._handle_audio_delta(mock_parsed)
+
+        # Check state was updated correctly
+        assert model._current_item_id == "test_item"
 
-            await model._handle_audio_delta(mock_parsed)
+        # Test that audio state is tracked correctly
+        audio_state = model._audio_state_tracker.get_state("test_item", 5)
+        assert audio_state is not None
+        assert audio_state.audio_length_ms == (4 / 24 / 2) * 1000  # 4 bytes in milliseconds
 
-            # Check state was updated correctly
-            assert model._current_audio_content_index == 5
-            assert model._current_item_id == "test_item"
-            assert model._audio_start_time == mock_now
-            assert model._audio_length_ms == 4 / 24 / 2  # 4 bytes
+        # Test that last audio item is tracked
+        last_item = model._audio_state_tracker.get_last_audio_item()
+        assert last_item == ("test_item", 5)
diff --git a/tests/realtime/test_playback_tracker.py b/tests/realtime/test_playback_tracker.py
new file mode 100644
index 000000000..c0bfba468
--- /dev/null
+++ b/tests/realtime/test_playback_tracker.py
@@ -0,0 +1,112 @@
+from unittest.mock import AsyncMock
+
+import pytest
+
+from agents.realtime._default_tracker import ModelAudioTracker
+from agents.realtime.model import RealtimePlaybackTracker
+from agents.realtime.model_inputs import RealtimeModelSendInterrupt
+from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel
+
+
+class TestPlaybackTracker:
+    """Test playback tracker functionality for interrupt timing."""
+
+    @pytest.fixture
+    def model(self):
+        """Create a fresh model instance for each test."""
+        return OpenAIRealtimeWebSocketModel()
+
+    @pytest.mark.asyncio
+    async def test_interrupt_timing_with_custom_playback_tracker(self, model):
+        """Test interrupt uses custom playback tracker elapsed time instead of default timing."""
+
+        # Create custom tracker and set elapsed time
+        custom_tracker = RealtimePlaybackTracker()
+        custom_tracker.set_audio_format("pcm16")
+        custom_tracker.on_play_ms("item_1", 1, 500.0)  # content_index 1, 500ms played
+
+        # Set up model with custom tracker directly
+        model._playback_tracker = custom_tracker
+
+        # Mock send_raw_message to capture interrupt
+        model._send_raw_message = AsyncMock()
+
+        # Send interrupt
+
+        await model._send_interrupt(RealtimeModelSendInterrupt())
+
+        # Should use custom tracker's 500ms elapsed time
+        model._send_raw_message.assert_called_once()
+        call_args = model._send_raw_message.call_args[0][0]
+        assert call_args.audio_end_ms == 500
+
+    @pytest.mark.asyncio
+    async def test_interrupt_skipped_when_no_audio_playing(self, model):
+        """Test interrupt returns early when no audio is currently playing."""
+        model._send_raw_message = AsyncMock()
+
+        # No audio playing (default state)
+
+        await model._send_interrupt(RealtimeModelSendInterrupt())
+
+        # Should not send any interrupt message
+        model._send_raw_message.assert_not_called()
+
+    def test_audio_state_accumulation_across_deltas(self):
+        """Test ModelAudioTracker accumulates audio length across multiple deltas."""
+
+        tracker = ModelAudioTracker()
+        tracker.set_audio_format("pcm16")
+
+        # Send multiple deltas for same item
+        tracker.on_audio_delta("item_1", 0, b"test")  # 4 bytes
+        tracker.on_audio_delta("item_1", 0, b"more")  # 4 bytes
+
+        state = tracker.get_state("item_1", 0)
+        assert state is not None
+        # Should accumulate: 8 bytes / 24 / 2 * 1000 = 166.67ms
+        expected_length = (8 / 24 / 2) * 1000
+        assert abs(state.audio_length_ms - expected_length) < 0.01
+
+    def test_state_cleanup_on_interruption(self):
+        """Test both trackers properly reset state on interruption."""
+
+        # Test ModelAudioTracker cleanup
+        model_tracker = ModelAudioTracker()
+        model_tracker.set_audio_format("pcm16")
+        model_tracker.on_audio_delta("item_1", 0, b"test")
+        assert model_tracker.get_last_audio_item() == ("item_1", 0)
+
+        model_tracker.on_interrupted()
+        assert model_tracker.get_last_audio_item() is None
+
+        # Test RealtimePlaybackTracker cleanup
+        playback_tracker = RealtimePlaybackTracker()
+        playback_tracker.on_play_ms("item_1", 0, 100.0)
+
+        state = playback_tracker.get_state()
+        assert state["current_item_id"] == "item_1"
+        assert state["elapsed_ms"] == 100.0
+
+        playback_tracker.on_interrupted()
+        state = playback_tracker.get_state()
+        assert state["current_item_id"] is None
+        assert state["elapsed_ms"] is None
+
+    def test_audio_length_calculation_with_different_formats(self):
+        """Test calculate_audio_length_ms handles g711 and PCM formats correctly."""
+        from agents.realtime._util import calculate_audio_length_ms
+
+        # Test g711 format (8kHz)
+        g711_bytes = b"12345678"  # 8 bytes
+        g711_length = calculate_audio_length_ms("g711_ulaw", g711_bytes)
+        assert g711_length == 1  # (8 / 8000) * 1000
+
+        # Test PCM format (24kHz, default)
+        pcm_bytes = b"test"  # 4 bytes
+        pcm_length = calculate_audio_length_ms("pcm16", pcm_bytes)
+        assert pcm_length == (4 / 24 / 2) * 1000  # ~83.33ms
+
+        # Test None format (defaults to PCM)
+        none_length = calculate_audio_length_ms(None, pcm_bytes)
+        assert none_length == pcm_length
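
A minimal usage sketch of the new playback tracker, built only from the APIs added above. The play_chunk callback and the surrounding audio-output loop are hypothetical placeholders, and API-key / model-settings handling is omitted:

    import asyncio

    from agents.realtime import RealtimeModelConfig, RealtimePlaybackTracker
    from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel


    async def main() -> None:
        # You own playback, so you own the tracker: report what was actually played.
        tracker = RealtimePlaybackTracker()
        tracker.set_audio_format("pcm16")  # match the session's output audio format

        model = OpenAIRealtimeWebSocketModel()
        config: RealtimeModelConfig = {"playback_tracker": tracker}
        await model.connect(config)  # api_key / initial_model_settings omitted in this sketch

        # Hypothetical playback callback: report each chunk as it is actually played,
        # so an interrupt gets truncated at the right audio_end_ms.
        def play_chunk(item_id: str, content_index: int, chunk: bytes) -> None:
            # ... hand `chunk` to your speaker, phone leg, or remote client here ...
            tracker.on_play_bytes(item_id, content_index, chunk)


    asyncio.run(main())

If no tracker is passed in RealtimeModelConfig, the model falls back to its default tracking, which (per the playback_tracker docstring above) assumes audio is played immediately at realtime speed.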