
Realtime: enable a playback tracker #1242


Open
wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions src/agents/realtime/__init__.py
@@ -47,6 +47,8 @@
RealtimeModel,
RealtimeModelConfig,
RealtimeModelListener,
RealtimePlaybackState,
RealtimePlaybackTracker,
)
from .model_events import (
RealtimeConnectionStatus,
@@ -139,6 +141,8 @@
"RealtimeModel",
"RealtimeModelConfig",
"RealtimeModelListener",
"RealtimePlaybackTracker",
"RealtimePlaybackState",
# Model Events
"RealtimeConnectionStatus",
"RealtimeModelAudioDoneEvent",
47 changes: 47 additions & 0 deletions src/agents/realtime/_default_tracker.py
@@ -0,0 +1,47 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime

from ._util import calculate_audio_length_ms
from .config import RealtimeAudioFormat


@dataclass
class ModelAudioState:
initial_received_time: datetime
audio_length_ms: float


class ModelAudioTracker:
def __init__(self) -> None:
# (item_id, item_content_index) -> ModelAudioState
self._states: dict[tuple[str, int], ModelAudioState] = {}
self._last_audio_item: tuple[str, int] | None = None

def set_audio_format(self, format: RealtimeAudioFormat) -> None:
"""Called when the model wants to set the audio format."""
self._format = format

def on_audio_delta(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
"""Called when an audio delta is received from the model."""
ms = calculate_audio_length_ms(self._format, bytes)
new_key = (item_id, item_content_index)

self._last_audio_item = new_key
if new_key not in self._states:
self._states[new_key] = ModelAudioState(datetime.now(), ms)
else:
self._states[new_key].audio_length_ms += ms

def on_interrupted(self) -> None:
"""Called when the audio playback has been interrupted."""
self._last_audio_item = None

def get_state(self, item_id: str, item_content_index: int) -> ModelAudioState | None:
"""Called when the model wants to get the current playback state."""
return self._states.get((item_id, item_content_index))

def get_last_audio_item(self) -> tuple[str, int] | None:
"""Called when the model wants to get the last audio item ID and content index."""
return self._last_audio_item
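
For orientation, a minimal sketch (not part of the diff) of how this internal tracker is driven; the `"pcm16"` format literal and the byte counts are illustrative assumptions:

```python
# Illustrative sketch: driving ModelAudioTracker the way the model class does.
from agents.realtime._default_tracker import ModelAudioTracker

tracker = ModelAudioTracker()
tracker.set_audio_format("pcm16")  # assumed format literal; set before deltas arrive

# Two deltas for the same (item_id, content_index) accumulate their audio length.
tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)

state = tracker.get_state("item_1", 0)
if state is not None:
    print(state.initial_received_time, state.audio_length_ms)

print(tracker.get_last_audio_item())  # ("item_1", 0)
tracker.on_interrupted()
print(tracker.get_last_audio_item())  # None
```
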
9 changes: 9 additions & 0 deletions src/agents/realtime/_util.py
@@ -0,0 +1,9 @@
from __future__ import annotations

from .config import RealtimeAudioFormat


def calculate_audio_length_ms(format: RealtimeAudioFormat | None, bytes: bytes) -> float:
if format and format.startswith("g711"):
return (len(bytes) / 8000) * 1000
return (len(bytes) / 24 / 2) * 1000
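
A quick usage sketch of the helper above; the `"pcm16"` and `"g711_ulaw"` literals are assumed values of `RealtimeAudioFormat`, and the chunks are synthetic silence sized to one second of audio in each encoding:

```python
# Illustrative only: exercise both branches of calculate_audio_length_ms.
from agents.realtime._util import calculate_audio_length_ms

pcm_chunk = b"\x00" * 48_000  # one second of 24 kHz, 16-bit PCM is 48,000 bytes
g711_chunk = b"\x00" * 8_000  # one second of 8 kHz g711 audio is 8,000 bytes

print(calculate_audio_length_ms("pcm16", pcm_chunk))
print(calculate_audio_length_ms("g711_ulaw", g711_chunk))
print(calculate_audio_length_ms(None, pcm_chunk))  # no format falls back to the PCM branch
```
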
94 changes: 94 additions & 0 deletions src/agents/realtime/model.py
@@ -6,13 +6,95 @@
from typing_extensions import NotRequired, TypedDict

from ..util._types import MaybeAwaitable
from ._util import calculate_audio_length_ms
from .config import (
RealtimeAudioFormat,
RealtimeSessionModelSettings,
)
from .model_events import RealtimeModelEvent
from .model_inputs import RealtimeModelSendEvent


class RealtimePlaybackState(TypedDict):
current_item_id: str | None
"""The item ID of the current item being played."""

current_item_content_index: int | None
"""The index of the current item content being played."""

elapsed_ms: float | None
"""The number of milliseconds of audio that have been played."""


class RealtimePlaybackTracker:
"""If you have custom playback logic or expect that audio is played with delays or at different
speeds, create an instance of RealtimePlaybackTracker and pass it to the session. You are
responsible for tracking the audio playback progress and calling `on_play_bytes` or
`on_play_ms` when the user has played some audio."""

def __init__(self) -> None:
self._format: RealtimeAudioFormat | None = None
# (item_id, item_content_index)
self._current_item: tuple[str, int] | None = None
self._elapsed_ms: float | None = None

def on_play_bytes(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
"""Called by you when you have played some audio.

Args:
item_id: The item ID of the audio being played.
item_content_index: The index of the audio content in `item.content`
bytes: The audio bytes that have been fully played.
"""
ms = calculate_audio_length_ms(self._format, bytes)
self.on_play_ms(item_id, item_content_index, ms)

def on_play_ms(self, item_id: str, item_content_index: int, ms: float) -> None:
"""Called by you when you have played some audio.

Args:
item_id: The item ID of the audio being played.
item_content_index: The index of the audio content in `item.content`
ms: The number of milliseconds of audio that have been played.
"""
if self._current_item != (item_id, item_content_index):
self._current_item = (item_id, item_content_index)
self._elapsed_ms = ms
else:
assert self._elapsed_ms is not None
self._elapsed_ms += ms

def on_interrupted(self) -> None:
"""Called by the model when the audio playback has been interrupted."""
self._current_item = None
self._elapsed_ms = None

def set_audio_format(self, format: RealtimeAudioFormat) -> None:
"""Will be called by the model to set the audio format.

Args:
format: The audio format to use.
"""
self._format = format

def get_state(self) -> RealtimePlaybackState:
"""Will be called by the model to get the current playback state."""
if self._current_item is None:
return {
"current_item_id": None,
"current_item_content_index": None,
"elapsed_ms": None,
}
assert self._elapsed_ms is not None

item_id, item_content_index = self._current_item
return {
"current_item_id": item_id,
"current_item_content_index": item_content_index,
"elapsed_ms": self._elapsed_ms,
}


class RealtimeModelListener(abc.ABC):
"""A listener for realtime transport events."""

@@ -39,6 +121,18 @@ class RealtimeModelConfig(TypedDict):
initial_model_settings: NotRequired[RealtimeSessionModelSettings]
"""The initial model settings to use when connecting."""

playback_tracker: NotRequired[RealtimePlaybackTracker]
"""The playback tracker to use when tracking audio playback progress. If not set, the model will
use a default implementation that assumes audio is played immediately, at realtime speed.

A playback tracker is useful for interruptions. The model generates audio much faster than
realtime playback speed. So if there's an interruption, it's useful for the model to know how
much of the audio has been played by the user. In low-latency scenarios, it's fine to assume
that audio is played back immediately at realtime speed. But in scenarios like phone calls or
other remote interactions, you can set a playback tracker that lets the model know when audio
is played to the user.
"""


class RealtimeModel(abc.ABC):
"""Interface for connecting to a realtime model and sending/receiving events."""
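
Tying the pieces in `model.py` together, a hedged sketch of how a custom tracker could be wired in. `RealtimePlaybackTracker` and the `playback_tracker` config key come from the diff above; the `on_audio_chunk_played` callback is a hypothetical stand-in for your own playout path, and how the config reaches the model depends on how you run the session:

```python
# Illustrative sketch: supplying a custom playback tracker via the model config.
from agents.realtime import RealtimeModelConfig, RealtimePlaybackTracker

tracker = RealtimePlaybackTracker()

config: RealtimeModelConfig = {
    "playback_tracker": tracker,
    # other keys (api_key, initial_model_settings, ...) as usual
}


def on_audio_chunk_played(item_id: str, content_index: int, chunk: bytes) -> None:
    # Hypothetical callback from your playout pipeline (e.g. a telephony bridge).
    # Report audio only once it has actually been rendered to the user.
    tracker.on_play_bytes(item_id, content_index, chunk)


# When the model later needs the playback position (e.g. to truncate on an
# interrupt), it reads exactly how much the user has heard:
print(tracker.get_state())
```
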
82 changes: 55 additions & 27 deletions src/agents/realtime/openai_realtime.py
@@ -57,6 +57,7 @@
from websockets.asyncio.client import ClientConnection

from agents.handoffs import Handoff
from agents.realtime._default_tracker import ModelAudioTracker
from agents.tool import FunctionTool, Tool
from agents.util._types import MaybeAwaitable

@@ -72,6 +73,8 @@
RealtimeModel,
RealtimeModelConfig,
RealtimeModelListener,
RealtimePlaybackState,
RealtimePlaybackTracker,
)
from .model_events import (
RealtimeModelAudioDoneEvent,
@@ -133,11 +136,10 @@ def __init__(self) -> None:
self._websocket_task: asyncio.Task[None] | None = None
self._listeners: list[RealtimeModelListener] = []
self._current_item_id: str | None = None
self._audio_start_time: datetime | None = None
self._audio_length_ms: float = 0.0
self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
self._ongoing_response: bool = False
self._current_audio_content_index: int | None = None
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
self._playback_tracker: RealtimePlaybackTracker | None = None

async def connect(self, options: RealtimeModelConfig) -> None:
"""Establish a connection to the model and keep it alive."""
@@ -146,6 +148,8 @@ async def connect(self, options: RealtimeModelConfig) -> None:

model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {})

self._playback_tracker = options.get("playback_tracker", RealtimePlaybackTracker())

self.model = model_settings.get("model_name", self.model)
api_key = await get_api_key(options.get("api_key"))

@@ -294,47 +298,75 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
if event.start_response:
await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))

def _get_playback_state(self) -> RealtimePlaybackState:
if self._playback_tracker:
return self._playback_tracker.get_state()

if last_audio_item_id := self._audio_state_tracker.get_last_audio_item():
item_id, item_content_index = last_audio_item_id
audio_state = self._audio_state_tracker.get_state(item_id, item_content_index)
if audio_state:
elapsed_ms = (
datetime.now() - audio_state.initial_received_time
).total_seconds() * 1000
return {
"current_item_id": item_id,
"current_item_content_index": item_content_index,
"elapsed_ms": elapsed_ms,
}

return {
"current_item_id": None,
"current_item_content_index": None,
"elapsed_ms": None,
}

async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
if not self._current_item_id or not self._audio_start_time:
playback_state = self._get_playback_state()
current_item_id = playback_state.get("current_item_id")
current_item_content_index = playback_state.get("current_item_content_index")
elapsed_ms = playback_state.get("elapsed_ms")
if current_item_id is None or elapsed_ms is None:
logger.info(
"Skipping interrupt. "
f"Item id: {current_item_id}, "
f"elapsed ms: {elapsed_ms}, "
f"content index: {current_item_content_index}"
)
return

await self._cancel_response()

elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
current_item_content_index = current_item_content_index or 0
if elapsed_ms > 0:
await self._emit_event(
RealtimeModelAudioInterruptedEvent(
item_id=self._current_item_id,
content_index=self._current_audio_content_index or 0,
item_id=current_item_id,
content_index=current_item_content_index,
)
)
converted = _ConversionHelper.convert_interrupt(
self._current_item_id,
self._current_audio_content_index or 0,
int(elapsed_time_ms),
current_item_id,
current_item_content_index,
int(elapsed_ms),
)
await self._send_raw_message(converted)
await self._cancel_response()

self._current_item_id = None
self._audio_start_time = None
self._audio_length_ms = 0.0
self._current_audio_content_index = None
self._audio_state_tracker.on_interrupted()
if self._playback_tracker:
self._playback_tracker.on_interrupted()

async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
"""Send a session update to the model."""
await self._update_session_config(event.session_settings)

async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
"""Handle audio delta events and update audio tracking state."""
self._current_audio_content_index = parsed.content_index
self._current_item_id = parsed.item_id
if self._audio_start_time is None:
self._audio_start_time = datetime.now()
self._audio_length_ms = 0.0

audio_bytes = base64.b64decode(parsed.delta)
# Calculate audio length in ms using 24KHz pcm16le
self._audio_length_ms += self._calculate_audio_length_ms(audio_bytes)

self._audio_state_tracker.on_audio_delta(parsed.item_id, parsed.content_index, audio_bytes)

await self._emit_event(
RealtimeModelAudioEvent(
data=audio_bytes,
@@ -344,10 +376,6 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
)
)

def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
"""Calculate audio length in milliseconds for 24KHz PCM16LE format."""
return len(audio_bytes) / 24 / 2

async def _handle_output_item(self, item: ConversationItem) -> None:
"""Handle response output item events (function calls and messages)."""
if item.type == "function_call" and item.status == "completed":
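
For context on where the elapsed milliseconds end up: the interrupt handler above converts the playback state into a truncation request via `_ConversionHelper.convert_interrupt` (not shown in this diff). A rough, self-contained illustration of the resulting payload, with field names as documented for the OpenAI Realtime API `conversation.item.truncate` event and example values:

```python
# Illustrative only: what the interrupt boils down to on the wire.
playback_state = {
    "current_item_id": "item_123",  # example values, normally from _get_playback_state()
    "current_item_content_index": 0,
    "elapsed_ms": 1850.0,
}

truncate_payload = {
    "type": "conversation.item.truncate",
    "item_id": playback_state["current_item_id"],
    "content_index": playback_state["current_item_content_index"],
    "audio_end_ms": int(playback_state["elapsed_ms"]),  # audio the user actually heard
}
print(truncate_payload)
```
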
2 changes: 2 additions & 0 deletions tests/realtime/test_agent.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest

from agents import RunContextWrapper
2 changes: 2 additions & 0 deletions tests/realtime/test_conversion_helpers.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
from unittest.mock import Mock
