From c9b1bd131404d67829227104bd8c8c84860a08ed Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 11 Jul 2025 20:26:35 -0400 Subject: [PATCH 1/2] [8/n] Make realtime more like the rest of agents sdk Key changes: 1. Transport -> model. 2. Extract any model settings into `RealtimeSessionModelSettings`. 3. RealtimeRunConfig, similar to the RunConfig in `run.py`. Next PR I'll update session to be better. --- .vscode/launch.json | 14 + examples/realtime/demo.py | 29 +- src/agents/realtime/__init__.py | 50 ++- src/agents/realtime/config.py | 51 +-- src/agents/realtime/events.py | 20 +- src/agents/realtime/model.py | 99 ++++++ .../{transport_events.py => model_events.py} | 54 ++-- src/agents/realtime/openai_realtime.py | 117 ++++--- src/agents/realtime/runner.py | 112 +++++++ src/agents/realtime/session.py | 303 +++++++----------- src/agents/realtime/transport.py | 107 ------- ...ansport_events.py => test_model_events.py} | 4 +- 12 files changed, 532 insertions(+), 428 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 src/agents/realtime/model.py rename src/agents/realtime/{transport_events.py => model_events.py} (66%) create mode 100644 src/agents/realtime/runner.py delete mode 100644 src/agents/realtime/transport.py rename tests/realtime/{test_transport_events.py => test_model_events.py} (68%) diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..a75c1414f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,14 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Python File", + "type": "debugpy", + "request": "launch", + "program": "${file}" + } + ] +} \ No newline at end of file diff --git a/examples/realtime/demo.py b/examples/realtime/demo.py index a2ea96545..61834a4e5 100644 --- a/examples/realtime/demo.py +++ b/examples/realtime/demo.py @@ -5,11 +5,13 @@ import numpy as np +from agents.realtime import RealtimeSession + # Add the current directory to path so we can import ui sys.path.append(os.path.dirname(os.path.abspath(__file__))) from agents import function_tool -from agents.realtime import RealtimeAgent, RealtimeSession, RealtimeSessionEvent +from agents.realtime import RealtimeAgent, RealtimeRunner, RealtimeSessionEvent if TYPE_CHECKING: from .ui import AppUI @@ -38,23 +40,34 @@ def get_weather(city: str) -> str: class Example: def __init__(self) -> None: - self.session = RealtimeSession(agent) self.ui = AppUI() self.ui.connected = asyncio.Event() self.ui.last_audio_item_id = None # Set the audio callback self.ui.set_audio_callback(self.on_audio_recorded) + self.session: RealtimeSession | None = None + async def run(self) -> None: - self.session.add_listener(self.on_event) - await self.session.connect() - self.ui.set_is_connected(True) - await self.ui.run_async() + # Start UI in a separate task instead of waiting for it to complete + ui_task = asyncio.create_task(self.ui.run_async()) + + # Set up session immediately without waiting for UI to finish + runner = RealtimeRunner(agent) + async with await runner.run() as session: + self.session = session + self.ui.set_is_connected(True) + async for event in session: + await self.on_event(event) + + # Wait for UI task to complete when session ends + await ui_task async def on_audio_recorded(self, audio_bytes: bytes) -> None: """Called when audio is recorded by the UI.""" try: # Send the audio to the session + assert 
self.session is not None await self.session.send_audio(audio_bytes) except Exception as e: self.ui.log_message(f"Error sending audio: {e}") @@ -87,8 +100,8 @@ async def on_event(self, event: RealtimeSessionEvent) -> None: pass elif event.type == "history_added": pass - elif event.type == "raw_transport_event": - self.ui.log_message(f"Raw transport event: {event.data}") + elif event.type == "raw_model_event": + self.ui.log_message(f"Raw model event: {event.data}") else: self.ui.log_message(f"Unknown event type: {event.type}") except Exception as e: diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index 7e5d4932a..52b59edc3 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -1,5 +1,16 @@ from .agent import RealtimeAgent, RealtimeAgentHooks, RealtimeRunHooks -from .config import APIKeyOrKeyFunc +from .config import ( + RealtimeAudioFormat, + RealtimeClientMessage, + RealtimeInputAudioTranscriptionConfig, + RealtimeModelName, + RealtimeRunConfig, + RealtimeSessionModelSettings, + RealtimeTurnDetectionConfig, + RealtimeUserInput, + RealtimeUserInputMessage, + RealtimeUserInputText, +) from .events import ( RealtimeAgentEndEvent, RealtimeAgentStartEvent, @@ -10,42 +21,49 @@ RealtimeHandoffEvent, RealtimeHistoryAdded, RealtimeHistoryUpdated, - RealtimeRawTransportEvent, + RealtimeRawModelEvent, RealtimeSessionEvent, RealtimeToolEnd, RealtimeToolStart, ) -from .session import RealtimeSession -from .transport import ( - RealtimeModelName, - RealtimeSessionTransport, - RealtimeTransportConnectionOptions, - RealtimeTransportListener, +from .model import ( + RealtimeModel, + RealtimeModelConfig, + RealtimeModelListener, ) +from .runner import RealtimeRunner +from .session import RealtimeSession __all__ = [ "RealtimeAgent", "RealtimeAgentHooks", "RealtimeRunHooks", - "RealtimeSession", - "RealtimeSessionListener", - "RealtimeSessionListenerFunc", - "APIKeyOrKeyFunc", + "RealtimeRunner", + "RealtimeRunConfig", + 
"RealtimeSessionModelSettings", + "RealtimeInputAudioTranscriptionConfig", + "RealtimeTurnDetectionConfig", + "RealtimeAudioFormat", + "RealtimeClientMessage", + "RealtimeUserInput", + "RealtimeUserInputMessage", + "RealtimeUserInputText", "RealtimeModelName", - "RealtimeSessionTransport", - "RealtimeTransportListener", - "RealtimeTransportConnectionOptions", + "RealtimeModel", + "RealtimeModelListener", + "RealtimeModelConfig", "RealtimeSessionEvent", "RealtimeAgentStartEvent", "RealtimeAgentEndEvent", "RealtimeHandoffEvent", "RealtimeToolStart", "RealtimeToolEnd", - "RealtimeRawTransportEvent", + "RealtimeRawModelEvent", "RealtimeAudioEnd", "RealtimeAudio", "RealtimeAudioInterrupted", "RealtimeError", "RealtimeHistoryUpdated", "RealtimeHistoryAdded", + "RealtimeSession", ] diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index aa15c837d..459a03a9a 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -1,9 +1,7 @@ from __future__ import annotations -import inspect from typing import ( Any, - Callable, Literal, Union, ) @@ -11,8 +9,20 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict from ..model_settings import ToolChoice -from ..tool import FunctionTool -from ..util._types import MaybeAwaitable +from ..tool import Tool + +RealtimeModelName: TypeAlias = Union[ + Literal[ + "gpt-4o-realtime-preview", + "gpt-4o-mini-realtime-preview", + "gpt-4o-realtime-preview-2025-06-03", + "gpt-4o-realtime-preview-2024-12-17", + "gpt-4o-realtime-preview-2024-10-01", + "gpt-4o-mini-realtime-preview-2024-12-17", + ], + str, +] +"""The name of a realtime model.""" class RealtimeClientMessage(TypedDict): @@ -20,7 +30,7 @@ class RealtimeClientMessage(TypedDict): other_data: NotRequired[dict[str, Any]] -class UserInputText(TypedDict): +class RealtimeUserInputText(TypedDict): type: Literal["input_text"] text: str @@ -28,7 +38,7 @@ class UserInputText(TypedDict): class RealtimeUserInputMessage(TypedDict): type: 
Literal["message"] role: Literal["user"] - content: list[UserInputText] + content: list[RealtimeUserInputText] RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage] @@ -55,9 +65,11 @@ class RealtimeTurnDetectionConfig(TypedDict): threshold: NotRequired[float] -class RealtimeSessionConfig(TypedDict): - api_key: NotRequired[APIKeyOrKeyFunc] - model: NotRequired[str] +class RealtimeSessionModelSettings(TypedDict): + """Model settings for a realtime model session.""" + + model_name: NotRequired[RealtimeModelName] + instructions: NotRequired[str] modalities: NotRequired[list[Literal["text", "audio"]]] voice: NotRequired[str] @@ -68,24 +80,13 @@ class RealtimeSessionConfig(TypedDict): turn_detection: NotRequired[RealtimeTurnDetectionConfig] tool_choice: NotRequired[ToolChoice] - tools: NotRequired[list[FunctionTool]] - - -APIKeyOrKeyFunc = str | Callable[[], MaybeAwaitable[str]] -"""Either an API key or a function that returns an API key.""" - + tools: NotRequired[list[Tool]] -async def get_api_key(key: APIKeyOrKeyFunc | None) -> str | None: - """Get the API key from the key or key function.""" - if key is None: - return None - elif isinstance(key, str): - return key - result = key() - if inspect.isawaitable(result): - return await result - return result +class RealtimeRunConfig(TypedDict): + model_settings: NotRequired[RealtimeSessionModelSettings] # TODO (rm) Add tracing support # tracing: NotRequired[RealtimeTracingConfig | None] + # TODO (rm) Add guardrail support + # TODO (rm) Add history audio storage config diff --git a/src/agents/realtime/events.py b/src/agents/realtime/events.py index bd6b7b5b0..44a588f03 100644 --- a/src/agents/realtime/events.py +++ b/src/agents/realtime/events.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import Any, Literal, Union @@ -7,7 +9,7 @@ from ..tool import Tool from .agent import RealtimeAgent from .items import RealtimeItem -from .transport_events import 
RealtimeTransportAudioEvent, RealtimeTransportEvent +from .model_events import RealtimeModelAudioEvent, RealtimeModelEvent @dataclass @@ -93,16 +95,16 @@ class RealtimeToolEnd: @dataclass -class RealtimeRawTransportEvent: - """Forwards raw events from the transport layer.""" +class RealtimeRawModelEvent: + """Forwards raw events from the model layer.""" - data: RealtimeTransportEvent - """The raw data from the transport layer.""" + data: RealtimeModelEvent + """The raw data from the model layer.""" info: RealtimeEventInfo """Common info for all events, such as the context.""" - type: Literal["raw_transport_event"] = "raw_transport_event" + type: Literal["raw_model_event"] = "raw_model_event" @dataclass @@ -119,8 +121,8 @@ class RealtimeAudioEnd: class RealtimeAudio: """Triggered when the agent generates new audio to be played.""" - audio: RealtimeTransportAudioEvent - """The audio event from the transport layer.""" + audio: RealtimeModelAudioEvent + """The audio event from the model layer.""" info: RealtimeEventInfo """Common info for all events, such as the context.""" @@ -187,7 +189,7 @@ class RealtimeHistoryAdded: RealtimeHandoffEvent, RealtimeToolStart, RealtimeToolEnd, - RealtimeRawTransportEvent, + RealtimeRawModelEvent, RealtimeAudioEnd, RealtimeAudio, RealtimeAudioInterrupted, diff --git a/src/agents/realtime/model.py b/src/agents/realtime/model.py new file mode 100644 index 000000000..2b41960e7 --- /dev/null +++ b/src/agents/realtime/model.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import abc +from typing import Any, Callable + +from typing_extensions import NotRequired, TypedDict + +from ..util._types import MaybeAwaitable +from .config import ( + RealtimeClientMessage, + RealtimeSessionModelSettings, + RealtimeUserInput, +) +from .model_events import RealtimeModelEvent, RealtimeModelToolCallEvent + + +class RealtimeModelListener(abc.ABC): + """A listener for realtime transport events.""" + + @abc.abstractmethod + async def on_event(self, 
event: RealtimeModelEvent) -> None: + """Called when an event is emitted by the realtime transport.""" + pass + + +class RealtimeModelConfig(TypedDict): + """Options for connecting to a realtime model.""" + + api_key: NotRequired[str | Callable[[], MaybeAwaitable[str]]] + """The API key (or function that returns a key) to use when connecting. If unset, the model will + try to use a sane default. For example, the OpenAI Realtime model will try to use the + `OPENAI_API_KEY` environment variable. + """ + + url: NotRequired[str] + """The URL to use when connecting. If unset, the model will use a sane default. For example, + the OpenAI Realtime model will use the default OpenAI WebSocket URL. + """ + + initial_model_settings: NotRequired[RealtimeSessionModelSettings] + + +class RealtimeModel(abc.ABC): + """Interface for connecting to a realtime model and sending/receiving events.""" + + @abc.abstractmethod + async def connect(self, options: RealtimeModelConfig) -> None: + """Establish a connection to the model and keep it alive.""" + pass + + @abc.abstractmethod + def add_listener(self, listener: RealtimeModelListener) -> None: + """Add a listener to the model.""" + pass + + @abc.abstractmethod + def remove_listener(self, listener: RealtimeModelListener) -> None: + """Remove a listener from the model.""" + pass + + @abc.abstractmethod + async def send_event(self, event: RealtimeClientMessage) -> None: + """Send an event to the model.""" + pass + + @abc.abstractmethod + async def send_message( + self, message: RealtimeUserInput, other_event_data: dict[str, Any] | None = None + ) -> None: + """Send a message to the model.""" + pass + + @abc.abstractmethod + async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: + """Send a raw audio chunk to the model. + + Args: + audio: The audio data to send. + commit: Whether to commit the audio buffer to the model. If the model does not do turn + detection, this can be used to indicate the turn is completed. 
+ """ + pass + + @abc.abstractmethod + async def send_tool_output( + self, tool_call: RealtimeModelToolCallEvent, output: str, start_response: bool + ) -> None: + """Send tool output to the model.""" + pass + + @abc.abstractmethod + async def interrupt(self) -> None: + """Interrupt the model. For example, could be triggered by a guardrail.""" + pass + + @abc.abstractmethod + async def close(self) -> None: + """Close the session.""" + pass diff --git a/src/agents/realtime/transport_events.py b/src/agents/realtime/model_events.py similarity index 66% rename from src/agents/realtime/transport_events.py rename to src/agents/realtime/model_events.py index 735577b17..a9769cebd 100644 --- a/src/agents/realtime/transport_events.py +++ b/src/agents/realtime/model_events.py @@ -11,7 +11,7 @@ @dataclass -class RealtimeTransportErrorEvent: +class RealtimeModelErrorEvent: """Represents a transport‑layer error.""" error: Any @@ -20,7 +20,7 @@ class RealtimeTransportErrorEvent: @dataclass -class RealtimeTransportToolCallEvent: +class RealtimeModelToolCallEvent: """Model attempted a tool/function call.""" name: str @@ -34,7 +34,7 @@ class RealtimeTransportToolCallEvent: @dataclass -class RealtimeTransportAudioEvent: +class RealtimeModelAudioEvent: """Raw audio bytes emitted by the model.""" data: bytes @@ -44,21 +44,21 @@ class RealtimeTransportAudioEvent: @dataclass -class RealtimeTransportAudioInterruptedEvent: +class RealtimeModelAudioInterruptedEvent: """Audio interrupted.""" type: Literal["audio_interrupted"] = "audio_interrupted" @dataclass -class RealtimeTransportAudioDoneEvent: +class RealtimeModelAudioDoneEvent: """Audio done.""" type: Literal["audio_done"] = "audio_done" @dataclass -class RealtimeTransportInputAudioTranscriptionCompletedEvent: +class RealtimeModelInputAudioTranscriptionCompletedEvent: """Input audio transcription completed.""" item_id: str @@ -70,7 +70,7 @@ class RealtimeTransportInputAudioTranscriptionCompletedEvent: @dataclass -class 
RealtimeTransportTranscriptDelta: +class RealtimeModelTranscriptDeltaEvent: """Partial transcript update.""" item_id: str @@ -81,7 +81,7 @@ class RealtimeTransportTranscriptDelta: @dataclass -class RealtimeTransportItemUpdatedEvent: +class RealtimeModelItemUpdatedEvent: """Item added to the history or updated.""" item: RealtimeItem @@ -90,7 +90,7 @@ class RealtimeTransportItemUpdatedEvent: @dataclass -class RealtimeTransportItemDeletedEvent: +class RealtimeModelItemDeletedEvent: """Item deleted from the history.""" item_id: str @@ -99,7 +99,7 @@ class RealtimeTransportItemDeletedEvent: @dataclass -class RealtimeTransportConnectionStatusEvent: +class RealtimeModelConnectionStatusEvent: """Connection status changed.""" status: RealtimeConnectionStatus @@ -108,21 +108,21 @@ class RealtimeTransportConnectionStatusEvent: @dataclass -class RealtimeTransportTurnStartedEvent: +class RealtimeModelTurnStartedEvent: """Triggered when the model starts generating a response for a turn.""" type: Literal["turn_started"] = "turn_started" @dataclass -class RealtimeTransportTurnEndedEvent: +class RealtimeModelTurnEndedEvent: """Triggered when the model finishes generating a response for a turn.""" type: Literal["turn_ended"] = "turn_ended" @dataclass -class RealtimeTransportOtherEvent: +class RealtimeModelOtherEvent: """Used as a catchall for vendor-specific events.""" data: Any @@ -133,18 +133,18 @@ class RealtimeTransportOtherEvent: # TODO (rm) Add usage events -RealtimeTransportEvent: TypeAlias = Union[ - RealtimeTransportErrorEvent, - RealtimeTransportToolCallEvent, - RealtimeTransportAudioEvent, - RealtimeTransportAudioInterruptedEvent, - RealtimeTransportAudioDoneEvent, - RealtimeTransportInputAudioTranscriptionCompletedEvent, - RealtimeTransportTranscriptDelta, - RealtimeTransportItemUpdatedEvent, - RealtimeTransportItemDeletedEvent, - RealtimeTransportConnectionStatusEvent, - RealtimeTransportTurnStartedEvent, - RealtimeTransportTurnEndedEvent, - RealtimeTransportOtherEvent, 
+RealtimeModelEvent: TypeAlias = Union[ + RealtimeModelErrorEvent, + RealtimeModelToolCallEvent, + RealtimeModelAudioEvent, + RealtimeModelAudioInterruptedEvent, + RealtimeModelAudioDoneEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelTranscriptDeltaEvent, + RealtimeModelItemUpdatedEvent, + RealtimeModelItemDeletedEvent, + RealtimeModelConnectionStatusEvent, + RealtimeModelTurnStartedEvent, + RealtimeModelTurnEndedEvent, + RealtimeModelOtherEvent, ] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 6dc34bcf1..b3f6eb2e0 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import asyncio import base64 +import inspect import json import os from datetime import datetime -from typing import Any +from typing import Any, Callable import websockets from openai.types.beta.realtime.realtime_server_event import ( @@ -12,52 +15,72 @@ from pydantic import TypeAdapter from websockets.asyncio.client import ClientConnection +from agents.util._types import MaybeAwaitable + from ..exceptions import UserError from ..logger import logger -from .config import RealtimeClientMessage, RealtimeUserInput, get_api_key +from .config import ( + RealtimeClientMessage, + RealtimeSessionModelSettings, + RealtimeUserInput, +) from .items import RealtimeMessageItem, RealtimeToolCallItem -from .transport import ( - RealtimeSessionTransport, - RealtimeTransportConnectionOptions, - RealtimeTransportListener, +from .model import ( + RealtimeModel, + RealtimeModelConfig, + RealtimeModelListener, ) -from .transport_events import ( - RealtimeTransportAudioDoneEvent, - RealtimeTransportAudioEvent, - RealtimeTransportAudioInterruptedEvent, - RealtimeTransportErrorEvent, - RealtimeTransportEvent, - RealtimeTransportInputAudioTranscriptionCompletedEvent, - RealtimeTransportItemDeletedEvent, - RealtimeTransportItemUpdatedEvent, - 
RealtimeTransportToolCallEvent, - RealtimeTransportTranscriptDelta, - RealtimeTransportTurnEndedEvent, - RealtimeTransportTurnStartedEvent, +from .model_events import ( + RealtimeModelAudioDoneEvent, + RealtimeModelAudioEvent, + RealtimeModelAudioInterruptedEvent, + RealtimeModelErrorEvent, + RealtimeModelEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelItemDeletedEvent, + RealtimeModelItemUpdatedEvent, + RealtimeModelToolCallEvent, + RealtimeModelTranscriptDeltaEvent, + RealtimeModelTurnEndedEvent, + RealtimeModelTurnStartedEvent, ) -class OpenAIRealtimeWebSocketTransport(RealtimeSessionTransport): - """A transport layer for realtime sessions that uses OpenAI's WebSocket API.""" +async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> str | None: + if isinstance(key, str): + return key + elif callable(key): + result = key() + if inspect.isawaitable(result): + return await result + return result + + return os.getenv("OPENAI_API_KEY") + + +class OpenAIRealtimeWebSocketModel(RealtimeModel): + """A model that uses OpenAI's WebSocket API.""" def __init__(self) -> None: self.model = "gpt-4o-realtime-preview" # Default model self._websocket: ClientConnection | None = None self._websocket_task: asyncio.Task[None] | None = None - self._listeners: list[RealtimeTransportListener] = [] + self._listeners: list[RealtimeModelListener] = [] self._current_item_id: str | None = None self._audio_start_time: datetime | None = None self._audio_length_ms: float = 0.0 self._ongoing_response: bool = False self._current_audio_content_index: int | None = None - async def connect(self, options: RealtimeTransportConnectionOptions) -> None: + async def connect(self, options: RealtimeModelConfig) -> None: """Establish a connection to the model and keep it alive.""" assert self._websocket is None, "Already connected" assert self._websocket_task is None, "Already connected" - self.model = options.get("model", self.model) - api_key = await 
get_api_key(options.get("api_key", os.getenv("OPENAI_API_KEY"))) + model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {}) + + self.model = model_settings.get("model_name", self.model) + api_key = await get_api_key(options.get("api_key")) if not api_key: raise UserError("API key is required but was not provided.") @@ -71,15 +94,15 @@ async def connect(self, options: RealtimeTransportConnectionOptions) -> None: self._websocket = await websockets.connect(url, additional_headers=headers) self._websocket_task = asyncio.create_task(self._listen_for_messages()) - def add_listener(self, listener: RealtimeTransportListener) -> None: - """Add a listener to the transport.""" + def add_listener(self, listener: RealtimeModelListener) -> None: + """Add a listener to the model.""" self._listeners.append(listener) - async def remove_listener(self, listener: RealtimeTransportListener) -> None: - """Remove a listener from the transport.""" + def remove_listener(self, listener: RealtimeModelListener) -> None: + """Remove a listener from the model.""" self._listeners.remove(listener) - async def _emit_event(self, event: RealtimeTransportEvent) -> None: + async def _emit_event(self, event: RealtimeModelEvent) -> None: """Emit an event to the listeners.""" for listener in self._listeners: await listener.on_event(event) @@ -154,7 +177,7 @@ async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: await self.send_event({"type": "input_audio_buffer.commit"}) async def send_tool_output( - self, tool_call: RealtimeTransportToolCallEvent, output: str, start_response: bool + self, tool_call: RealtimeModelToolCallEvent, output: str, start_response: bool ) -> None: """Send tool output to the model.""" await self.send_event( @@ -179,7 +202,7 @@ async def send_tool_output( name=tool_call.name, output=output, ) - await self._emit_event(RealtimeTransportItemUpdatedEvent(item=tool_item)) + await 
self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item)) if start_response: await self.send_event({"type": "response.create"}) @@ -193,7 +216,7 @@ async def interrupt(self) -> None: elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000 if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms: - await self._emit_event(RealtimeTransportAudioInterruptedEvent()) + await self._emit_event(RealtimeModelAudioInterruptedEvent()) await self.send_event( { "type": "conversation.item.truncate", @@ -231,9 +254,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): ).validate_python(event) except Exception as e: logger.error(f"Invalid event: {event} - {e}") - await self._emit_event( - RealtimeTransportErrorEvent(error=f"Invalid event: {event} - {e}") - ) + # await self._emit_event(RealtimeModelErrorEvent(error=f"Invalid event: {event} - {e}")) return if parsed.type == "response.audio.delta": @@ -247,25 +268,25 @@ async def _handle_ws_event(self, event: dict[str, Any]): # Calculate audio length in ms using 24KHz pcm16le self._audio_length_ms += len(audio_bytes) / 24 / 2 await self._emit_event( - RealtimeTransportAudioEvent(data=audio_bytes, response_id=parsed.response_id) + RealtimeModelAudioEvent(data=audio_bytes, response_id=parsed.response_id) ) elif parsed.type == "response.audio.done": - await self._emit_event(RealtimeTransportAudioDoneEvent()) + await self._emit_event(RealtimeModelAudioDoneEvent()) elif parsed.type == "input_audio_buffer.speech_started": await self.interrupt() elif parsed.type == "response.created": self._ongoing_response = True - await self._emit_event(RealtimeTransportTurnStartedEvent()) + await self._emit_event(RealtimeModelTurnStartedEvent()) elif parsed.type == "response.done": self._ongoing_response = False - await self._emit_event(RealtimeTransportTurnEndedEvent()) + await self._emit_event(RealtimeModelTurnEndedEvent()) elif parsed.type == "session.created": # TODO (rm) tracing stuff here pass elif 
parsed.type == "error": - await self._emit_event(RealtimeTransportErrorEvent(error=parsed.error)) + await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) elif parsed.type == "conversation.item.deleted": - await self._emit_event(RealtimeTransportItemDeletedEvent(item_id=parsed.item_id)) + await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id)) elif ( parsed.type == "conversation.item.created" or parsed.type == "conversation.item.retrieved" @@ -284,7 +305,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): "status": "in_progress", } ) - await self._emit_event(RealtimeTransportItemUpdatedEvent(item=message_item)) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) elif ( parsed.type == "conversation.item.input_audio_transcription.completed" or parsed.type == "conversation.item.truncated" @@ -299,13 +320,13 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) if parsed.type == "conversation.item.input_audio_transcription.completed": await self._emit_event( - RealtimeTransportInputAudioTranscriptionCompletedEvent( + RealtimeModelInputAudioTranscriptionCompletedEvent( item_id=parsed.item_id, transcript=parsed.transcript ) ) elif parsed.type == "response.audio_transcript.delta": await self._emit_event( - RealtimeTransportTranscriptDelta( + RealtimeModelTranscriptDeltaEvent( item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id ) ) @@ -333,9 +354,9 @@ async def _handle_ws_event(self, event: dict[str, Any]): name=item.name or "", output=None, ) - await self._emit_event(RealtimeTransportItemUpdatedEvent(item=tool_call)) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call)) await self._emit_event( - RealtimeTransportToolCallEvent( + RealtimeModelToolCallEvent( call_id=item.id or "", name=item.name or "", arguments=item.arguments or "", @@ -352,4 +373,4 @@ async def _handle_ws_event(self, event: dict[str, Any]): "status": "in_progress", } ) - await 
self._emit_event(RealtimeTransportItemUpdatedEvent(item=message_item)) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) diff --git a/src/agents/realtime/runner.py b/src/agents/realtime/runner.py new file mode 100644 index 000000000..4470ab220 --- /dev/null +++ b/src/agents/realtime/runner.py @@ -0,0 +1,112 @@ +"""Minimal realtime session implementation for voice agents.""" + +from __future__ import annotations + +import asyncio + +from ..run_context import RunContextWrapper, TContext +from .agent import RealtimeAgent +from .config import ( + RealtimeRunConfig, + RealtimeSessionModelSettings, +) +from .model import ( + RealtimeModel, + RealtimeModelConfig, +) +from .openai_realtime import OpenAIRealtimeWebSocketModel +from .session import RealtimeSession + + +class RealtimeRunner: + """A `RealtimeRunner` is the equivalent of `Runner` for realtime agents. It automatically + handles multiple turns by maintaining a persistent connection with the underlying model + layer. + + The session manages the local history copy, executes tools, runs guardrails and facilitates + handoffs between agents. + + Since this code runs on your server, it uses WebSockets by default. You can optionally create + your own custom model layer by implementing the `RealtimeModel` interface. + """ + + def __init__( + self, + starting_agent: RealtimeAgent, + *, + model: RealtimeModel | None = None, + config: RealtimeRunConfig | None = None, + ) -> None: + """Initialize the realtime runner. + + Args: + starting_agent: The agent to start the session with. + context: The context to use for the session. + model: The model to use. If not provided, will use a default OpenAI realtime model. + config: Override parameters to use for the entire run. 
+ """ + self._starting_agent = starting_agent + self._config = config + self._model = model or OpenAIRealtimeWebSocketModel() + + async def run( + self, *, context: TContext | None = None, model_config: RealtimeModelConfig | None = None + ) -> RealtimeSession: + """Start and returns a realtime session. + + Returns: + RealtimeSession: A session object that allows bidirectional communication with the + realtime model. + + Example: + ```python + runner = RealtimeRunner(agent) + async with await runner.run() as session: + await session.send_message("Hello") + async for event in session: + print(event) + ``` + """ + model_settings = await self._get_model_settings( + agent=self._starting_agent, + initial_settings=model_config.get("initial_model_settings") if model_config else None, + overrides=self._config.get("model_settings") if self._config else None, + ) + + model_config = model_config.copy() if model_config else {} + model_config["initial_model_settings"] = model_settings + + # Create and return the connection + session = RealtimeSession( + model=self._model, + agent=self._starting_agent, + context=context, + model_config=model_config, + ) + + return session + + async def _get_model_settings( + self, + agent: RealtimeAgent, + context: TContext | None = None, + initial_settings: RealtimeSessionModelSettings | None = None, + overrides: RealtimeSessionModelSettings | None = None, + ) -> RealtimeSessionModelSettings: + context_wrapper = RunContextWrapper(context) + model_settings = initial_settings.copy() if initial_settings else {} + + instructions, tools = await asyncio.gather( + agent.get_system_prompt(context_wrapper), + agent.get_all_tools(context_wrapper), + ) + + if instructions is not None: + model_settings["instructions"] = instructions + if tools is not None: + model_settings["tools"] = tools + + if overrides: + model_settings.update(overrides) + + return model_settings diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 
fd2a0a3ff..927e74e9f 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -1,20 +1,17 @@ -"""Minimal realtime session implementation for voice agents.""" - from __future__ import annotations -import abc import asyncio -from collections.abc import Awaitable -from typing import Any, Callable, Literal +from collections.abc import AsyncIterator +from typing import Any -from typing_extensions import TypeAlias, assert_never +from typing_extensions import assert_never from ..handoffs import Handoff -from ..run_context import RunContextWrapper +from ..run_context import RunContextWrapper, TContext from ..tool import FunctionTool from ..tool_context import ToolContext from .agent import RealtimeAgent -from .config import APIKeyOrKeyFunc, RealtimeSessionConfig, RealtimeUserInput +from .config import RealtimeUserInput from .events import ( RealtimeAgentEndEvent, RealtimeAgentStartEvent, @@ -23,253 +20,186 @@ RealtimeAudioInterrupted, RealtimeError, RealtimeEventInfo, - RealtimeHandoffEvent, # noqa: F401 RealtimeHistoryAdded, RealtimeHistoryUpdated, - RealtimeRawTransportEvent, + RealtimeRawModelEvent, RealtimeSessionEvent, RealtimeToolEnd, RealtimeToolStart, ) from .items import InputAudio, InputText, RealtimeItem -from .openai_realtime import OpenAIRealtimeWebSocketTransport -from .transport import ( - RealtimeModelName, - RealtimeSessionTransport, - RealtimeTransportConnectionOptions, - RealtimeTransportListener, -) -from .transport_events import ( - RealtimeTransportEvent, - RealtimeTransportInputAudioTranscriptionCompletedEvent, - RealtimeTransportToolCallEvent, +from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener +from .model_events import ( + RealtimeModelEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelToolCallEvent, ) -class RealtimeSessionListener(abc.ABC): - """A listener for realtime session events.""" - - @abc.abstractmethod - async def on_event(self, event: RealtimeSessionEvent) -> 
None: - """Called when an event is emitted by the realtime session.""" - pass - - -RealtimeSessionListenerFunc: TypeAlias = Callable[[RealtimeSessionEvent], Awaitable[None]] -"""A function that can be used as a listener for realtime session events.""" - - -class _RealtimeFuncListener(RealtimeSessionListener): - """A listener that wraps a function.""" - - def __init__(self, func: RealtimeSessionListenerFunc) -> None: - self._func = func - - async def on_event(self, event: RealtimeSessionEvent) -> None: - """Call the wrapped function with the event.""" - await self._func(event) - - -class RealtimeSession(RealtimeTransportListener): - """A `RealtimeSession` is the equivalent of `Runner` for realtime agents. It automatically - handles multiple turns by maintaining a persistent connection with the underlying transport - layer. - - The session manages the local history copy, executes tools, runs guardrails and facilitates - handoffs between agents. - - Since this code runs on your server, it uses WebSockets by default. You can optionally create - your own custom transport layer by implementing the `RealtimeSessionTransport` interface. +class RealtimeSession(RealtimeModelListener): + """A connection to a realtime model. It streams events from the model to you, and allows you to + send messages and audio to the model. 
+ + Example: + ```python + runner = RealtimeRunner(agent) + async with await runner.run() as session: + # Send messages + await session.send_message("Hello") + await session.send_audio(audio_bytes) + + # Stream events + async for event in session: + if event.type == "audio": + # Handle audio event + pass + ``` """ def __init__( self, - starting_agent: RealtimeAgent, - *, - context: Any | None = None, - transport: Literal["websocket"] | RealtimeSessionTransport = "websocket", - api_key: APIKeyOrKeyFunc | None = None, - model: RealtimeModelName | None = None, - config: RealtimeSessionConfig | None = None, - # TODO (rm) Add guardrail support - # TODO (rm) Add tracing support - # TODO (rm) Add history audio storage config + model: RealtimeModel, + agent: RealtimeAgent, + context: TContext | None, + model_config: RealtimeModelConfig | None = None, ) -> None: - """Initialize the realtime session. + """Initialize the session. Args: - starting_agent: The agent to start the session with. - context: The context to use for the session. - transport: The transport to use for the session. Defaults to using websockets. - api_key: The API key to use for the session. - model: The model to use. Must be a realtime model. - config: Override parameters to use. + model: The model to use. + agent: The current agent. + context: The context object for the session. It is wrapped in a + `RunContextWrapper`, which is exposed to emitted events through + `RealtimeEventInfo`. + model_config: Model configuration. 
""" - self._current_agent = starting_agent + self._model = model + self._current_agent = agent self._context_wrapper = RunContextWrapper(context) self._event_info = RealtimeEventInfo(context=self._context_wrapper) - self._override_config = config self._history: list[RealtimeItem] = [] - self._model = model - self._api_key = api_key - - self._listeners: list[RealtimeSessionListener] = [] - - if transport == "websocket": - self._transport: RealtimeSessionTransport = OpenAIRealtimeWebSocketTransport() - else: - self._transport = transport + self._model_config = model_config or {} + self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue() + self._closed = False + self._background_task: asyncio.Task[None] | None = None async def __aenter__(self) -> RealtimeSession: - """Async context manager entry.""" - await self.connect() - return self - - async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: - """Async context manager exit.""" - await self.end() - - async def connect(self) -> None: - """Start the session: connect to the model and start the connection.""" - self._transport.add_listener(self) - - config = await self.create_session_config( - overrides=self._override_config, - ) - - options: RealtimeTransportConnectionOptions = { - "initial_session_config": config, - } - - if config.get("api_key") is not None: - options["api_key"] = config["api_key"] - elif self._api_key is not None: - options["api_key"] = self._api_key - - if config.get("model") is not None: - options["model"] = config["model"] - elif self._model is not None: - options["model"] = self._model + """Start the session by connecting to the model. After this, you will be able to stream + events from the model and send messages and audio to the model. 
+ """ + # Add ourselves as a listener + self._model.add_listener(self) - await self._transport.connect(options) + # Connect to the model + await self._model.connect(self._model_config) - await self._emit_event( + # Emit initial history update + await self._put_event( RealtimeHistoryUpdated( history=self._history, info=self._event_info, ) ) - async def end(self) -> None: - """End the session: disconnect from the model and close the connection.""" - pass - - def add_listener(self, listener: RealtimeSessionListener | RealtimeSessionListenerFunc) -> None: - """Add a listener to the session.""" - if isinstance(listener, RealtimeSessionListener): - self._listeners.append(listener) - else: - self._listeners.append(_RealtimeFuncListener(listener)) - - def remove_listener( - self, listener: RealtimeSessionListener | RealtimeSessionListenerFunc - ) -> None: - """Remove a listener from the session.""" - if isinstance(listener, RealtimeSessionListener): - self._listeners.remove(listener) - else: - for x in self._listeners: - if isinstance(x, _RealtimeFuncListener) and x._func == listener: - self._listeners.remove(x) - break - - async def create_session_config( - self, overrides: RealtimeSessionConfig | None = None - ) -> RealtimeSessionConfig: - """Create the session config.""" - agent = self._current_agent - instructions, tools = await asyncio.gather( - agent.get_system_prompt(self._context_wrapper), - agent.get_all_tools(self._context_wrapper), - ) - config = RealtimeSessionConfig() - - if self._model is not None: - config["model"] = self._model - if instructions is not None: - config["instructions"] = instructions - if tools is not None: - config["tools"] = [tool for tool in tools if isinstance(tool, FunctionTool)] + return self - if overrides: - config.update(overrides) + async def enter(self) -> RealtimeSession: + """Enter the async context manager. We strongly recommend using the async context manager + pattern instead of this method. 
If you use this, you need to manually call `close()` when + you are done. + """ + return await self.__aenter__() - return config + async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: + """End the session.""" + await self.close() + + async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]: + """Iterate over events from the session.""" + while not self._closed: + try: + event = await self._event_queue.get() + yield event + except asyncio.CancelledError: + break + + async def close(self) -> None: + """Close the session.""" + self._closed = True + self._model.remove_listener(self) + await self._model.close() + + # Cancel any background tasks + if self._background_task and not self._background_task.done(): + self._background_task.cancel() + try: + await self._background_task + except asyncio.CancelledError: + pass async def send_message(self, message: RealtimeUserInput) -> None: """Send a message to the model.""" - await self._transport.send_message(message) + await self._model.send_message(message) async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: """Send a raw audio chunk to the model.""" - await self._transport.send_audio(audio, commit=commit) + await self._model.send_audio(audio, commit=commit) async def interrupt(self) -> None: """Interrupt the model.""" - await self._transport.interrupt() + await self._model.interrupt() - async def on_event(self, event: RealtimeTransportEvent) -> None: - """Called when an event is emitted by the realtime transport.""" - await self._emit_event(RealtimeRawTransportEvent(data=event, info=self._event_info)) + async def on_event(self, event: RealtimeModelEvent) -> None: + await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info)) if event.type == "error": - await self._emit_event(RealtimeError(info=self._event_info, error=event.error)) + await self._put_event(RealtimeError(info=self._event_info, error=event.error)) elif event.type == "function_call": - 
await self._handle_tool_call(event) + # Handle tool calls in the background to avoid blocking event stream + self._background_task = asyncio.create_task(self._handle_tool_call(event)) elif event.type == "audio": - await self._emit_event(RealtimeAudio(info=self._event_info, audio=event)) + await self._put_event(RealtimeAudio(info=self._event_info, audio=event)) elif event.type == "audio_interrupted": - await self._emit_event(RealtimeAudioInterrupted(info=self._event_info)) + await self._put_event(RealtimeAudioInterrupted(info=self._event_info)) elif event.type == "audio_done": - await self._emit_event(RealtimeAudioEnd(info=self._event_info)) + await self._put_event(RealtimeAudioEnd(info=self._event_info)) elif event.type == "conversation.item.input_audio_transcription.completed": self._history = self._get_new_history(self._history, event) - await self._emit_event( + await self._put_event( RealtimeHistoryUpdated(info=self._event_info, history=self._history) ) elif event.type == "transcript_delta": # TODO (rm) Add guardrails pass elif event.type == "item_updated": - is_new = any(item.item_id == event.item.item_id for item in self._history) + is_new = not any(item.item_id == event.item.item_id for item in self._history) self._history = self._get_new_history(self._history, event.item) if is_new: new_item = next( item for item in self._history if item.item_id == event.item.item_id ) - await self._emit_event(RealtimeHistoryAdded(info=self._event_info, item=new_item)) + await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item)) else: - await self._emit_event( + await self._put_event( RealtimeHistoryUpdated(info=self._event_info, history=self._history) ) - pass elif event.type == "item_deleted": deleted_id = event.item_id self._history = [item for item in self._history if item.item_id != deleted_id] - await self._emit_event( + await self._put_event( RealtimeHistoryUpdated(info=self._event_info, history=self._history) ) elif event.type == 
"connection_status": pass elif event.type == "turn_started": - await self._emit_event( + await self._put_event( RealtimeAgentStartEvent( agent=self._current_agent, info=self._event_info, ) ) elif event.type == "turn_ended": - await self._emit_event( + await self._put_event( RealtimeAgentEndEvent( agent=self._current_agent, info=self._event_info, @@ -280,17 +210,18 @@ async def on_event(self, event: RealtimeTransportEvent) -> None: else: assert_never(event) - async def _emit_event(self, event: RealtimeSessionEvent) -> None: - """Emit an event to the listeners.""" - await asyncio.gather(*[listener.on_event(event) for listener in self._listeners]) + async def _put_event(self, event: RealtimeSessionEvent) -> None: + """Put an event into the queue.""" + await self._event_queue.put(event) - async def _handle_tool_call(self, event: RealtimeTransportToolCallEvent) -> None: + async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: + """Handle a tool call event.""" all_tools = await self._current_agent.get_all_tools(self._context_wrapper) function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} handoff_map = {tool.name: tool for tool in all_tools if isinstance(tool, Handoff)} if event.name in function_map: - await self._emit_event( + await self._put_event( RealtimeToolStart( info=self._event_info, tool=function_map[event.name], @@ -302,9 +233,9 @@ async def _handle_tool_call(self, event: RealtimeTransportToolCallEvent) -> None tool_context = ToolContext.from_agent_context(self._context_wrapper, event.call_id) result = await func_tool.on_invoke_tool(tool_context, event.arguments) - await self._transport.send_tool_output(event, str(result), True) + await self._model.send_tool_output(event, str(result), True) - await self._emit_event( + await self._put_event( RealtimeToolEnd( info=self._event_info, tool=func_tool, @@ -322,10 +253,10 @@ async def _handle_tool_call(self, event: RealtimeTransportToolCallEvent) -> None def 
_get_new_history( self, old_history: list[RealtimeItem], - event: RealtimeTransportInputAudioTranscriptionCompletedEvent | RealtimeItem, + event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem, ) -> list[RealtimeItem]: # Merge transcript into placeholder input_audio message. - if isinstance(event, RealtimeTransportInputAudioTranscriptionCompletedEvent): + if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent): new_history: list[RealtimeItem] = [] for item in old_history: if item.item_id == event.item_id and item.type == "message" and item.role == "user": @@ -355,7 +286,7 @@ def _get_new_history( new_history[existing_index] = event return new_history # Otherwise, insert it after the previous_item_id if that is set - elif item.previous_item_id: + elif event.previous_item_id: # Insert the new item after the previous item previous_index = next( (i for i, item in enumerate(old_history) if item.item_id == event.previous_item_id), diff --git a/src/agents/realtime/transport.py b/src/agents/realtime/transport.py deleted file mode 100644 index 18290d128..000000000 --- a/src/agents/realtime/transport.py +++ /dev/null @@ -1,107 +0,0 @@ -import abc -from typing import Any, Literal, Union - -from typing_extensions import NotRequired, TypeAlias, TypedDict - -from .config import APIKeyOrKeyFunc, RealtimeClientMessage, RealtimeSessionConfig, RealtimeUserInput -from .transport_events import RealtimeTransportEvent, RealtimeTransportToolCallEvent - -RealtimeModelName: TypeAlias = Union[ - Literal[ - "gpt-4o-realtime-preview", - "gpt-4o-mini-realtime-preview", - "gpt-4o-realtime-preview-2025-06-03", - "gpt-4o-realtime-preview-2024-12-17", - "gpt-4o-realtime-preview-2024-10-01", - "gpt-4o-mini-realtime-preview-2024-12-17", - ], - str, -] -"""The name of a realtime model.""" - - -class RealtimeTransportListener(abc.ABC): - """A listener for realtime transport events.""" - - @abc.abstractmethod - async def on_event(self, event: RealtimeTransportEvent) 
-> None: - """Called when an event is emitted by the realtime transport.""" - pass - - -class RealtimeTransportConnectionOptions(TypedDict): - """Options for connecting to a realtime transport.""" - - api_key: NotRequired[APIKeyOrKeyFunc] - """The API key to use for the transport. If unset, the transport will attempt to use the - `OPENAI_API_KEY` environment variable. - """ - - model: NotRequired[str] - """The model to use.""" - - url: NotRequired[str] - """The URL to use for the transport. If unset, the transport will use the default OpenAI - WebSocket URL. - """ - - initial_session_config: NotRequired[RealtimeSessionConfig] - - -class RealtimeSessionTransport(abc.ABC): - """A transport layer for realtime sessions.""" - - @abc.abstractmethod - async def connect(self, options: RealtimeTransportConnectionOptions) -> None: - """Establish a connection to the model and keep it alive.""" - pass - - @abc.abstractmethod - def add_listener(self, listener: RealtimeTransportListener) -> None: - """Add a listener to the transport.""" - pass - - @abc.abstractmethod - async def remove_listener(self, listener: RealtimeTransportListener) -> None: - """Remove a listener from the transport.""" - pass - - @abc.abstractmethod - async def send_event(self, event: RealtimeClientMessage) -> None: - """Send an event to the model.""" - pass - - @abc.abstractmethod - async def send_message( - self, message: RealtimeUserInput, other_event_data: dict[str, Any] | None = None - ) -> None: - """Send a message to the model.""" - pass - - @abc.abstractmethod - async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: - """Send a raw audio chunk to the model. - - Args: - audio: The audio data to send. - commit: Whether to commit the audio buffer to the model. If the model does not do turn - detection, this can be used to indicate the turn is completed. 
- """ - pass - - @abc.abstractmethod - async def send_tool_output( - self, tool_call: RealtimeTransportToolCallEvent, output: str, start_response: bool - ) -> None: - """Send tool output to the model.""" - pass - - @abc.abstractmethod - async def interrupt(self) -> None: - """Interrupt the model. For example, could be triggered by a guardrail.""" - pass - - @abc.abstractmethod - async def close(self) -> None: - """Close the session.""" - pass diff --git a/tests/realtime/test_transport_events.py b/tests/realtime/test_model_events.py similarity index 68% rename from tests/realtime/test_transport_events.py rename to tests/realtime/test_model_events.py index 2219303d0..b8696cc29 100644 --- a/tests/realtime/test_transport_events.py +++ b/tests/realtime/test_model_events.py @@ -1,11 +1,11 @@ from typing import get_args -from agents.realtime.transport_events import RealtimeTransportEvent +from agents.realtime.model_events import RealtimeModelEvent def test_all_events_have_type() -> None: """Test that all events have a type.""" - events = get_args(RealtimeTransportEvent) + events = get_args(RealtimeModelEvent) assert len(events) > 0 for event in events: assert event.type is not None From e29ece8a20cb1d5615e8cca4995278a70b4c4235 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 11 Jul 2025 20:33:07 -0400 Subject: [PATCH 2/2] Tests for realtime runner --- tests/realtime/test_runner.py | 224 ++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 tests/realtime/test_runner.py diff --git a/tests/realtime/test_runner.py b/tests/realtime/test_runner.py new file mode 100644 index 000000000..aabdff140 --- /dev/null +++ b/tests/realtime/test_runner.py @@ -0,0 +1,224 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from inline_snapshot import snapshot + +from agents.realtime.agent import RealtimeAgent +from agents.realtime.config import RealtimeRunConfig, RealtimeSessionModelSettings +from agents.realtime.model import 
RealtimeModel, RealtimeModelConfig +from agents.realtime.runner import RealtimeRunner +from agents.realtime.session import RealtimeSession + + +class MockRealtimeModel(RealtimeModel): + async def connect(self, options=None): + pass + + def add_listener(self, listener): + pass + + def remove_listener(self, listener): + pass + + async def send_event(self, event): + pass + + async def send_message(self, message, other_event_data=None): + pass + + async def send_audio(self, audio, commit=False): + pass + + async def send_tool_output(self, tool_call, output, start_response=True): + pass + + async def interrupt(self): + pass + + async def close(self): + pass + + +@pytest.fixture +def mock_agent(): + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value="Test instructions") + agent.get_all_tools = AsyncMock(return_value=[{"type": "function", "name": "test_tool"}]) + return agent + + +@pytest.fixture +def mock_model(): + return MockRealtimeModel() + + +@pytest.mark.asyncio +async def test_run_creates_session_with_no_settings(mock_agent, mock_model): + """Test that run() creates a session correctly if no settings are provided""" + runner = RealtimeRunner(mock_agent, model=mock_model) + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + session = await runner.run() + + # Verify session was created with correct parameters + mock_session_class.assert_called_once() + call_args = mock_session_class.call_args + + assert call_args[1]["model"] == mock_model + assert call_args[1]["agent"] == mock_agent + assert call_args[1]["context"] is None + + # Verify model_config contains expected settings from agent + model_config = call_args[1]["model_config"] + assert model_config == snapshot( + { + "initial_model_settings": { + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + } + } + ) + + 
assert session == mock_session + + +@pytest.mark.asyncio +async def test_run_creates_session_with_settings_only_in_init(mock_agent, mock_model): + """Test that it creates a session with the right settings if they are provided only in init""" + config = RealtimeRunConfig( + model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova") + ) + runner = RealtimeRunner(mock_agent, model=mock_model, config=config) + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + _ = await runner.run() + + # Verify session was created with config overrides + call_args = mock_session_class.call_args + model_config = call_args[1]["model_config"] + + # Should have agent settings plus config overrides + assert model_config == snapshot( + { + "initial_model_settings": { + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + "model_name": "gpt-4o-realtime", + "voice": "nova", + } + } + ) + + +@pytest.mark.asyncio +async def test_run_creates_session_with_settings_in_both_init_and_run_overrides( + mock_agent, mock_model +): + """Test settings in both init and run() - init should override run()""" + init_config = RealtimeRunConfig( + model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova") + ) + runner = RealtimeRunner(mock_agent, model=mock_model, config=init_config) + + run_model_config: RealtimeModelConfig = { + "initial_model_settings": RealtimeSessionModelSettings( + voice="alloy", input_audio_format="pcm16" + ) + } + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + _ = await runner.run(model_config=run_model_config) + + # Verify init settings override run() settings + call_args = mock_session_class.call_args + model_config = 
call_args[1]["model_config"] + + # Should have run() settings and agent settings, with init config overriding run() + assert model_config == snapshot( + { + "initial_model_settings": { + "voice": "nova", + "input_audio_format": "pcm16", + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + "model_name": "gpt-4o-realtime", + } + } + ) + + +@pytest.mark.asyncio +async def test_run_creates_session_with_settings_only_in_run(mock_agent, mock_model): + """Test settings provided only in run()""" + runner = RealtimeRunner(mock_agent, model=mock_model) + + run_model_config: RealtimeModelConfig = { + "initial_model_settings": RealtimeSessionModelSettings( + model_name="gpt-4o-realtime-preview", voice="shimmer", modalities=["text", "audio"] + ) + } + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + _ = await runner.run(model_config=run_model_config) + + # Verify run() settings are applied + call_args = mock_session_class.call_args + model_config = call_args[1]["model_config"] + + # Should have agent settings plus run() settings + assert model_config == snapshot( + { + "initial_model_settings": { + "model_name": "gpt-4o-realtime-preview", + "voice": "shimmer", + "modalities": ["text", "audio"], + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + } + } + ) + + +@pytest.mark.asyncio +async def test_run_with_context_parameter(mock_agent, mock_model): + """Test that context parameter is passed through to session""" + runner = RealtimeRunner(mock_agent, model=mock_model) + test_context = {"user_id": "test123"} + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + await runner.run(context=test_context) + + call_args = mock_session_class.call_args + assert 
call_args[1]["context"] == test_context + + +@pytest.mark.asyncio +async def test_get_model_settings_with_none_values(mock_model): + """Test _get_model_settings handles None values from agent properly""" + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value=None) + agent.get_all_tools = AsyncMock(return_value=None) + + runner = RealtimeRunner(agent, model=mock_model) + + with patch("agents.realtime.runner.RealtimeSession"): + await runner.run() + + # Should not crash and agent methods should be called + agent.get_system_prompt.assert_called_once() + agent.get_all_tools.assert_called_once()