Realtime: forward exceptions from transport layer #1107

Merged 2 commits on Jul 14, 2025

2 changes: 1 addition & 1 deletion examples/mcp/prompt_server/main.py
@@ -17,7 +17,7 @@ async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str,
try:
prompt_result = await mcp_server.get_prompt(prompt_name, kwargs)
content = prompt_result.messages[0].content
if hasattr(content, 'text'):
if hasattr(content, "text"):
instructions = content.text
else:
instructions = str(content)
3 changes: 3 additions & 0 deletions src/agents/model_settings.py
@@ -42,15 +42,18 @@ def validate_from_none(value: None) -> _Omit:
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
)


@dataclass
class MCPToolChoice:
server_label: str
name: str


Omit = Annotated[_Omit, _OmitTypeAnnotation]
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]


@dataclass
class ModelSettings:
"""Settings to use when calling an LLM.
2 changes: 1 addition & 1 deletion src/agents/models/openai_responses.py
@@ -343,7 +343,7 @@ def convert_tool_choice(
elif tool_choice == "mcp":
# Note that this is still here for backwards compatibility,
# but migrating to MCPToolChoice is recommended.
return { "type": "mcp" } # type: ignore [typeddict-item]
return {"type": "mcp"} # type: ignore [typeddict-item]
else:
return {
"type": "function",
1 change: 1 addition & 0 deletions src/agents/realtime/config.py
@@ -29,6 +29,7 @@
class RealtimeClientMessage(TypedDict):
type: str # explicitly required
other_data: NotRequired[dict[str, Any]]
"""Merged into the message body."""


class RealtimeUserInputText(TypedDict):
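
For context, `other_data` is not sent as a nested field: `send_event` merges it into the top-level payload (see the `openai_realtime.py` diff below). A minimal sketch of that merge, using an illustrative payload:

```python
import json

# Hypothetical client message; the "audio" field riding along in other_data is illustrative.
event = {
    "type": "input_audio_buffer.append",
    "other_data": {"audio": "base64-encoded-bytes"},
}

# send_event flattens other_data into the top-level payload before sending it
# over the websocket.
converted = {"type": event["type"], **event.get("other_data", {})}
print(json.dumps(converted))
# -> {"type": "input_audio_buffer.append", "audio": "base64-encoded-bytes"}
```
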
11 changes: 11 additions & 0 deletions src/agents/realtime/model_events.py
@@ -130,6 +130,16 @@ class RealtimeModelOtherEvent:
type: Literal["other"] = "other"


@dataclass
class RealtimeModelExceptionEvent:
"""Exception occurred during model operation."""

exception: Exception
context: str | None = None

type: Literal["exception"] = "exception"


# TODO (rm) Add usage events


@@ -147,4 +157,5 @@ class RealtimeModelOtherEvent:
RealtimeModelTurnStartedEvent,
RealtimeModelTurnEndedEvent,
RealtimeModelOtherEvent,
RealtimeModelExceptionEvent,
]
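
With `RealtimeModelExceptionEvent` in the `RealtimeModelEvent` union, downstream listeners can branch on `type == "exception"`. A minimal sketch of a listener reacting to it, assuming only the async `on_event` hook that the tests below exercise:

```python
import logging

from agents.realtime.model_events import RealtimeModelEvent

logger = logging.getLogger(__name__)


class LoggingListener:
    """Hypothetical listener that logs forwarded transport exceptions."""

    async def on_event(self, event: RealtimeModelEvent) -> None:
        if event.type == "exception":
            # event.exception carries the original error; event.context says where it happened.
            logger.error("Realtime transport failure (%s): %r", event.context, event.exception)
        # ... handle the other event types as needed
```
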
126 changes: 87 additions & 39 deletions src/agents/realtime/openai_realtime.py
@@ -39,6 +39,7 @@
RealtimeModelAudioInterruptedEvent,
RealtimeModelErrorEvent,
RealtimeModelEvent,
RealtimeModelExceptionEvent,
RealtimeModelInputAudioTranscriptionCompletedEvent,
RealtimeModelItemDeletedEvent,
RealtimeModelItemUpdatedEvent,
@@ -130,48 +131,84 @@ async def _listen_for_messages(self):

try:
async for message in self._websocket:
parsed = json.loads(message)
await self._handle_ws_event(parsed)
try:
parsed = json.loads(message)
await self._handle_ws_event(parsed)
except json.JSONDecodeError as e:
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="Failed to parse WebSocket message as JSON"
)
)
except Exception as e:
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="Error handling WebSocket event"
)
)

except websockets.exceptions.ConnectionClosed:
# TODO connection closed handling (event, cleanup)
logger.warning("WebSocket connection closed")
except websockets.exceptions.ConnectionClosedOK:
# Normal connection closure - no exception event needed
logger.info("WebSocket connection closed normally")
except websockets.exceptions.ConnectionClosed as e:
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="WebSocket connection closed unexpectedly"
)
)
except Exception as e:
logger.error(f"WebSocket error: {e}")
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context="WebSocket error in message listener"
)
)

async def send_event(self, event: RealtimeClientMessage) -> None:
"""Send an event to the model."""
assert self._websocket is not None, "Not connected"
converted_event = {
"type": event["type"],
}

converted_event.update(event.get("other_data", {}))
try:
converted_event = {
"type": event["type"],
}

await self._websocket.send(json.dumps(converted_event))
converted_event.update(event.get("other_data", {}))

await self._websocket.send(json.dumps(converted_event))
except Exception as e:
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e, context=f"Failed to send event: {event.get('type', 'unknown')}"
)
)

async def send_message(
self, message: RealtimeUserInput, other_event_data: dict[str, Any] | None = None
) -> None:
"""Send a message to the model."""
message = (
message
if isinstance(message, dict)
else {
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": message}],
try:
message = (
message
if isinstance(message, dict)
else {
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": message}],
}
)
other_data = {
"item": message,
}
)
other_data = {
"item": message,
}
if other_event_data:
other_data.update(other_event_data)
if other_event_data:
other_data.update(other_event_data)

await self.send_event({"type": "conversation.item.create", "other_data": other_data})
await self.send_event({"type": "conversation.item.create", "other_data": other_data})

await self.send_event({"type": "response.create"})
await self.send_event({"type": "response.create"})
except Exception as e:
await self._emit_event(
RealtimeModelExceptionEvent(exception=e, context="Failed to send message")
)

async def send_audio(self, audio: bytes, *, commit: bool = False) -> None:
"""Send a raw audio chunk to the model.
@@ -182,17 +219,23 @@ async def send_audio(self, audio: bytes, *, commit: bool = False) -> None:
detection, this can be used to indicate the turn is completed.
"""
assert self._websocket is not None, "Not connected"
base64_audio = base64.b64encode(audio).decode("utf-8")
await self.send_event(
{
"type": "input_audio_buffer.append",
"other_data": {
"audio": base64_audio,
},
}
)
if commit:
await self.send_event({"type": "input_audio_buffer.commit"})

try:
base64_audio = base64.b64encode(audio).decode("utf-8")
await self.send_event(
{
"type": "input_audio_buffer.append",
"other_data": {
"audio": base64_audio,
},
}
)
if commit:
await self.send_event({"type": "input_audio_buffer.commit"})
except Exception as e:
await self._emit_event(
RealtimeModelExceptionEvent(exception=e, context="Failed to send audio")
)

async def send_tool_output(
self, tool_call: RealtimeModelToolCallEvent, output: str, start_response: bool
@@ -342,8 +385,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
OpenAIRealtimeServerEvent
).validate_python(event)
except Exception as e:
logger.error(f"Invalid event: {event} - {e}")
# await self._emit_event(RealtimeModelErrorEvent(error=f"Invalid event: {event} - {e}"))
event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown"
await self._emit_event(
RealtimeModelExceptionEvent(
exception=e,
context=f"Failed to validate server event: {event_type}",
)
)
return

if parsed.type == "response.audio.delta":
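
One design note on the listener loop above: in the `websockets` package, `ConnectionClosedOK` subclasses `ConnectionClosed`, so the clean-closure handler must come first or every normal shutdown would be forwarded as an exception event. A quick sanity check of that assumption:

```python
import websockets.exceptions

# ConnectionClosedOK (clean close) specializes ConnectionClosed, which is why
# _listen_for_messages catches it before the generic ConnectionClosed handler.
assert issubclass(
    websockets.exceptions.ConnectionClosedOK,
    websockets.exceptions.ConnectionClosed,
)
```
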
29 changes: 25 additions & 4 deletions src/agents/realtime/session.py
@@ -84,6 +84,7 @@ def __init__(
self._run_config = run_config or {}
self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
self._closed = False
self._stored_exception: Exception | None = None

# Guardrails state tracking
self._interrupted_by_guardrail = False
@@ -130,17 +131,20 @@ async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
"""Iterate over events from the session."""
while not self._closed:
try:
# Check if there's a stored exception to raise
if self._stored_exception is not None:
# Clean up resources before raising
await self._cleanup()
raise self._stored_exception

event = await self._event_queue.get()
yield event
except asyncio.CancelledError:
break

async def close(self) -> None:
"""Close the session."""
self._closed = True
self._cleanup_guardrail_tasks()
self._model.remove_listener(self)
await self._model.close()
await self._cleanup()

async def send_message(self, message: RealtimeUserInput) -> None:
"""Send a message to the model."""
@@ -228,6 +232,9 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
info=self._event_info,
)
)
elif event.type == "exception":
# Store the exception to be raised in __aiter__
self._stored_exception = event.exception
elif event.type == "other":
pass
else:
@@ -403,3 +410,17 @@ def _cleanup_guardrail_tasks(self) -> None:
if not task.done():
task.cancel()
self._guardrail_tasks.clear()

async def _cleanup(self) -> None:
"""Clean up all resources and mark session as closed."""
# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()

# Remove ourselves as a listener
self._model.remove_listener(self)

# Close the model connection
await self._model.close()

# Mark as closed
self._closed = True
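
From the caller's side, the stored exception surfaces on the next iteration of the session. A rough usage sketch (the surrounding setup that produces `session` is assumed, not part of this diff):

```python
async def consume(session) -> None:
    """Hypothetical consumer of a RealtimeSession."""
    try:
        async for event in session:
            ...  # handle RealtimeSessionEvent instances
    except Exception as exc:
        # Transport errors forwarded as "exception" model events are re-raised
        # here after the session has already cleaned up its resources.
        print(f"Realtime session failed: {exc!r}")
```
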
33 changes: 14 additions & 19 deletions tests/realtime/test_openai_realtime.py
@@ -174,39 +174,34 @@ class TestEventHandlingRobustness(TestOpenAIRealtimeWebSocketModel):

@pytest.mark.asyncio
async def test_handle_malformed_json_logs_error_continues(self, model):
"""Test that malformed JSON is logged as error but doesn't crash."""
"""Test that malformed JSON emits exception event but doesn't crash."""
mock_listener = AsyncMock()
model.add_listener(mock_listener)

# Malformed JSON should not crash the handler
with patch("agents.realtime.openai_realtime.logger") as mock_logger:
await model._handle_ws_event("invalid json {")
await model._handle_ws_event("invalid json {")

# Should log error but not crash
mock_logger.error.assert_called_once()
assert "Invalid event" in mock_logger.error.call_args[0][0]

# Should not emit any events to listeners
mock_listener.on_event.assert_not_called()
# Should emit exception event to listeners
mock_listener.on_event.assert_called_once()
exception_event = mock_listener.on_event.call_args[0][0]
assert exception_event.type == "exception"
assert "Failed to validate server event: unknown" in exception_event.context

@pytest.mark.asyncio
async def test_handle_invalid_event_schema_logs_error(self, model):
"""Test that events with invalid schema are logged but don't crash."""
"""Test that events with invalid schema emit exception events but don't crash."""
mock_listener = AsyncMock()
model.add_listener(mock_listener)

invalid_event = {"type": "response.audio.delta"} # Missing required fields

with patch("agents.realtime.openai_realtime.logger") as mock_logger:
await model._handle_ws_event(invalid_event)
await model._handle_ws_event(invalid_event)

# Should log validation error
mock_logger.error.assert_called_once()
error_msg = mock_logger.error.call_args[0][0]
assert "Invalid event" in error_msg

# Should not emit events to listeners
mock_listener.on_event.assert_not_called()
# Should emit exception event to listeners
mock_listener.on_event.assert_called_once()
exception_event = mock_listener.on_event.call_args[0][0]
assert exception_event.type == "exception"
assert "Failed to validate server event: response.audio.delta" in exception_event.context

@pytest.mark.asyncio
async def test_handle_unknown_event_type_ignored(self, model):