diff --git a/examples/mcp/prompt_server/main.py b/examples/mcp/prompt_server/main.py index 8f2991fc0..4caa95d88 100644 --- a/examples/mcp/prompt_server/main.py +++ b/examples/mcp/prompt_server/main.py @@ -17,7 +17,7 @@ async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str, try: prompt_result = await mcp_server.get_prompt(prompt_name, kwargs) content = prompt_result.messages[0].content - if hasattr(content, 'text'): + if hasattr(content, "text"): instructions = content.text else: instructions = str(content) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 1e9edcbc6..edb692960 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -42,15 +42,18 @@ def validate_from_none(value: None) -> _Omit: serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None), ) + @dataclass class MCPToolChoice: server_label: str name: str + Omit = Annotated[_Omit, _OmitTypeAnnotation] Headers: TypeAlias = Mapping[str, Union[str, Omit]] ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None] + @dataclass class ModelSettings: """Settings to use when calling an LLM. diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index d25613aee..76c67903c 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -343,7 +343,7 @@ def convert_tool_choice( elif tool_choice == "mcp": # Note that this is still here for backwards compatibility, # but migrating to MCPToolChoice is recommended. - return { "type": "mcp" } # type: ignore [typeddict-item] + return {"type": "mcp"} # type: ignore [typeddict-item] else: return { "type": "function", diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index ece72e755..9333a9cca 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -29,6 +29,7 @@ class RealtimeClientMessage(TypedDict): type: str # explicitly required other_data: NotRequired[dict[str, Any]] + """Merged into the message body.""" class RealtimeUserInputText(TypedDict): diff --git a/src/agents/realtime/model_events.py b/src/agents/realtime/model_events.py index a9769cebd..de1ad5f54 100644 --- a/src/agents/realtime/model_events.py +++ b/src/agents/realtime/model_events.py @@ -130,6 +130,16 @@ class RealtimeModelOtherEvent: type: Literal["other"] = "other" +@dataclass +class RealtimeModelExceptionEvent: + """Exception occurred during model operation.""" + + exception: Exception + context: str | None = None + + type: Literal["exception"] = "exception" + + # TODO (rm) Add usage events @@ -147,4 +157,5 @@ class RealtimeModelOtherEvent: RealtimeModelTurnStartedEvent, RealtimeModelTurnEndedEvent, RealtimeModelOtherEvent, + RealtimeModelExceptionEvent, ] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 797753242..e239fd003 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -39,6 +39,7 @@ RealtimeModelAudioInterruptedEvent, RealtimeModelErrorEvent, RealtimeModelEvent, + RealtimeModelExceptionEvent, RealtimeModelInputAudioTranscriptionCompletedEvent, RealtimeModelItemDeletedEvent, RealtimeModelItemUpdatedEvent, @@ -130,48 +131,84 @@ async def _listen_for_messages(self): try: async for message in self._websocket: - parsed = json.loads(message) - await self._handle_ws_event(parsed) + try: + parsed = json.loads(message) + await self._handle_ws_event(parsed) + except json.JSONDecodeError as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="Failed to parse WebSocket message as JSON" + ) + ) + except Exception as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="Error handling WebSocket event" + ) + ) - except websockets.exceptions.ConnectionClosed: - # TODO connection closed handling (event, cleanup) - logger.warning("WebSocket connection closed") + except websockets.exceptions.ConnectionClosedOK: + # Normal connection closure - no exception event needed + logger.info("WebSocket connection closed normally") + except websockets.exceptions.ConnectionClosed as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="WebSocket connection closed unexpectedly" + ) + ) except Exception as e: - logger.error(f"WebSocket error: {e}") + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="WebSocket error in message listener" + ) + ) async def send_event(self, event: RealtimeClientMessage) -> None: """Send an event to the model.""" assert self._websocket is not None, "Not connected" - converted_event = { - "type": event["type"], - } - converted_event.update(event.get("other_data", {})) + try: + converted_event = { + "type": event["type"], + } - await self._websocket.send(json.dumps(converted_event)) + converted_event.update(event.get("other_data", {})) + + await self._websocket.send(json.dumps(converted_event)) + except Exception as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context=f"Failed to send event: {event.get('type', 'unknown')}" + ) + ) async def send_message( self, message: RealtimeUserInput, other_event_data: dict[str, Any] | None = None ) -> None: """Send a message to the model.""" - message = ( - message - if isinstance(message, dict) - else { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": message}], + try: + message = ( + message + if isinstance(message, dict) + else { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": message}], + } + ) + other_data = { + "item": message, } - ) - other_data = { - "item": message, - } - if other_event_data: - other_data.update(other_event_data) + if other_event_data: + other_data.update(other_event_data) - await self.send_event({"type": "conversation.item.create", "other_data": other_data}) + await self.send_event({"type": "conversation.item.create", "other_data": other_data}) - await self.send_event({"type": "response.create"}) + await self.send_event({"type": "response.create"}) + except Exception as e: + await self._emit_event( + RealtimeModelExceptionEvent(exception=e, context="Failed to send message") + ) async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: """Send a raw audio chunk to the model. @@ -182,17 +219,23 @@ async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: detection, this can be used to indicate the turn is completed. """ assert self._websocket is not None, "Not connected" - base64_audio = base64.b64encode(audio).decode("utf-8") - await self.send_event( - { - "type": "input_audio_buffer.append", - "other_data": { - "audio": base64_audio, - }, - } - ) - if commit: - await self.send_event({"type": "input_audio_buffer.commit"}) + + try: + base64_audio = base64.b64encode(audio).decode("utf-8") + await self.send_event( + { + "type": "input_audio_buffer.append", + "other_data": { + "audio": base64_audio, + }, + } + ) + if commit: + await self.send_event({"type": "input_audio_buffer.commit"}) + except Exception as e: + await self._emit_event( + RealtimeModelExceptionEvent(exception=e, context="Failed to send audio") + ) async def send_tool_output( self, tool_call: RealtimeModelToolCallEvent, output: str, start_response: bool @@ -342,8 +385,13 @@ async def _handle_ws_event(self, event: dict[str, Any]): OpenAIRealtimeServerEvent ).validate_python(event) except Exception as e: - logger.error(f"Invalid event: {event} - {e}") - # await self._emit_event(RealtimeModelErrorEvent(error=f"Invalid event: {event} - {e}")) + event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown" + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, + context=f"Failed to validate server event: {event_type}", + ) + ) return if parsed.type == "response.audio.delta": diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 8c909fd07..ce8b7d705 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -84,6 +84,7 @@ def __init__( self._run_config = run_config or {} self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue() self._closed = False + self._stored_exception: Exception | None = None # Guardrails state tracking self._interrupted_by_guardrail = False @@ -130,6 +131,12 @@ async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]: """Iterate over events from the session.""" while not self._closed: try: + # Check if there's a stored exception to raise + if self._stored_exception is not None: + # Clean up resources before raising + await self._cleanup() + raise self._stored_exception + event = await self._event_queue.get() yield event except asyncio.CancelledError: @@ -137,10 +144,7 @@ async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]: async def close(self) -> None: """Close the session.""" - self._closed = True - self._cleanup_guardrail_tasks() - self._model.remove_listener(self) - await self._model.close() + await self._cleanup() async def send_message(self, message: RealtimeUserInput) -> None: """Send a message to the model.""" @@ -228,6 +232,9 @@ async def on_event(self, event: RealtimeModelEvent) -> None: info=self._event_info, ) ) + elif event.type == "exception": + # Store the exception to be raised in __aiter__ + self._stored_exception = event.exception elif event.type == "other": pass else: @@ -403,3 +410,17 @@ def _cleanup_guardrail_tasks(self) -> None: if not task.done(): task.cancel() self._guardrail_tasks.clear() + + async def _cleanup(self) -> None: + """Clean up all resources and mark session as closed.""" + # Cancel and cleanup guardrail tasks + self._cleanup_guardrail_tasks() + + # Remove ourselves as a listener + self._model.remove_listener(self) + + # Close the model connection + await self._model.close() + + # Mark as closed + self._closed = True diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 9fe2f9acb..7cd9a5c30 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -174,39 +174,34 @@ class TestEventHandlingRobustness(TestOpenAIRealtimeWebSocketModel): @pytest.mark.asyncio async def test_handle_malformed_json_logs_error_continues(self, model): - """Test that malformed JSON is logged as error but doesn't crash.""" + """Test that malformed JSON emits exception event but doesn't crash.""" mock_listener = AsyncMock() model.add_listener(mock_listener) # Malformed JSON should not crash the handler - with patch("agents.realtime.openai_realtime.logger") as mock_logger: - await model._handle_ws_event("invalid json {") + await model._handle_ws_event("invalid json {") - # Should log error but not crash - mock_logger.error.assert_called_once() - assert "Invalid event" in mock_logger.error.call_args[0][0] - - # Should not emit any events to listeners - mock_listener.on_event.assert_not_called() + # Should emit exception event to listeners + mock_listener.on_event.assert_called_once() + exception_event = mock_listener.on_event.call_args[0][0] + assert exception_event.type == "exception" + assert "Failed to validate server event: unknown" in exception_event.context @pytest.mark.asyncio async def test_handle_invalid_event_schema_logs_error(self, model): - """Test that events with invalid schema are logged but don't crash.""" + """Test that events with invalid schema emit exception events but don't crash.""" mock_listener = AsyncMock() model.add_listener(mock_listener) invalid_event = {"type": "response.audio.delta"} # Missing required fields - with patch("agents.realtime.openai_realtime.logger") as mock_logger: - await model._handle_ws_event(invalid_event) + await model._handle_ws_event(invalid_event) - # Should log validation error - mock_logger.error.assert_called_once() - error_msg = mock_logger.error.call_args[0][0] - assert "Invalid event" in error_msg - - # Should not emit events to listeners - mock_listener.on_event.assert_not_called() + # Should emit exception event to listeners + mock_listener.on_event.assert_called_once() + exception_event = mock_listener.on_event.call_args[0][0] + assert exception_event.type == "exception" + assert "Failed to validate server event: response.audio.delta" in exception_event.context @pytest.mark.asyncio async def test_handle_unknown_event_type_ignored(self, model): diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 2cd71b023..f6bd60064 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -974,27 +974,30 @@ class TestGuardrailFunctionality: async def _wait_for_guardrail_tasks(self, session): """Wait for all pending guardrail tasks to complete.""" import asyncio + if session._guardrail_tasks: await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) @pytest.fixture def triggered_guardrail(self): """Creates a guardrail that always triggers""" + def guardrail_func(context, agent, output): return GuardrailFunctionOutput( - output_info={"reason": "test trigger"}, - tripwire_triggered=True + output_info={"reason": "test trigger"}, tripwire_triggered=True ) + return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail") @pytest.fixture def safe_guardrail(self): """Creates a guardrail that never triggers""" + def guardrail_func(context, agent, output): return GuardrailFunctionOutput( - output_info={"reason": "safe content"}, - tripwire_triggered=False + output_info={"reason": "safe content"}, tripwire_triggered=False ) + return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail") @pytest.mark.asyncio @@ -1004,7 +1007,7 @@ async def test_transcript_delta_triggers_guardrail_at_threshold( """Test that guardrails run when transcript delta reaches debounce threshold""" run_config: RealtimeRunConfig = { "output_guardrails": [triggered_guardrail], - "guardrails_settings": {"debounce_text_length": 10} + "guardrails_settings": {"debounce_text_length": 10}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) @@ -1041,20 +1044,20 @@ async def test_transcript_delta_multiple_thresholds_same_item( """Test guardrails run at 1x, 2x, 3x thresholds for same item_id""" run_config: RealtimeRunConfig = { "output_guardrails": [triggered_guardrail], - "guardrails_settings": {"debounce_text_length": 5} + "guardrails_settings": {"debounce_text_length": 5}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) # First delta - reaches 1x threshold (5 chars) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="12345", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="12345", response_id="resp_1") + ) # Second delta - reaches 2x threshold (10 chars total) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="67890", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="67890", response_id="resp_1") + ) # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) @@ -1070,28 +1073,32 @@ async def test_transcript_delta_different_items_tracked_separately( """Test that different item_ids are tracked separately for debouncing""" run_config: RealtimeRunConfig = { "output_guardrails": [safe_guardrail], - "guardrails_settings": {"debounce_text_length": 10} + "guardrails_settings": {"debounce_text_length": 10}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) # Add text to item_1 (8 chars - below threshold) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="12345678", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345678", response_id="resp_1" + ) + ) # Add text to item_2 (8 chars - below threshold) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_2", delta="abcdefgh", response_id="resp_2" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_2", delta="abcdefgh", response_id="resp_2" + ) + ) # Neither should trigger guardrails yet assert mock_model.interrupts_called == 0 # Add more text to item_1 (total 12 chars - above threshold) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="90ab", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="90ab", response_id="resp_1") + ) # item_1 should have triggered guardrail run (but not interrupted since safe) assert session._item_guardrail_run_counts["item_1"] == 1 @@ -1107,15 +1114,17 @@ async def test_turn_ended_clears_guardrail_state( """Test that turn_ended event clears guardrail state for next turn""" run_config: RealtimeRunConfig = { "output_guardrails": [triggered_guardrail], - "guardrails_settings": {"debounce_text_length": 5} + "guardrails_settings": {"debounce_text_length": 5}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) # Trigger guardrail - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="trigger", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + ) + ) # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) @@ -1132,16 +1141,13 @@ async def test_turn_ended_clears_guardrail_state( assert len(session._item_guardrail_run_counts) == 0 @pytest.mark.asyncio - async def test_multiple_guardrails_all_triggered( - self, mock_model, mock_agent - ): + async def test_multiple_guardrails_all_triggered(self, mock_model, mock_agent): """Test that all triggered guardrails are included in the event""" + def create_triggered_guardrail(name): def guardrail_func(context, agent, output): - return GuardrailFunctionOutput( - output_info={"name": name}, - tripwire_triggered=True - ) + return GuardrailFunctionOutput(output_info={"name": name}, tripwire_triggered=True) + return OutputGuardrail(guardrail_function=guardrail_func, name=name) guardrail1 = create_triggered_guardrail("guardrail_1") @@ -1149,14 +1155,16 @@ def guardrail_func(context, agent, output): run_config: RealtimeRunConfig = { "output_guardrails": [guardrail1, guardrail2], - "guardrails_settings": {"debounce_text_length": 5} + "guardrails_settings": {"debounce_text_length": 5}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="trigger", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + ) + ) # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index 456ae125f..7fb2594a9 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -222,27 +222,24 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): # Create a test agent and runner with tracing disabled agent = RealtimeAgent(name="test_agent", instructions="test") - runner = RealtimeRunner( - starting_agent=agent, - config={"tracing_disabled": True} - ) + runner = RealtimeRunner(starting_agent=agent, config={"tracing_disabled": True}) # Test the _get_model_settings method directly since that's where the logic is model_settings = await runner._get_model_settings( agent=agent, disable_tracing=True, # This should come from config["tracing_disabled"] initial_settings=None, - overrides=None + overrides=None, ) # When tracing is disabled, model settings should have tracing=None assert model_settings["tracing"] is None # Also test that the runner passes disable_tracing=True correctly - with patch.object(runner, '_get_model_settings') as mock_get_settings: + with patch.object(runner, "_get_model_settings") as mock_get_settings: mock_get_settings.return_value = {"tracing": None} - with patch('agents.realtime.session.RealtimeSession') as mock_session_class: + with patch("agents.realtime.session.RealtimeSession") as mock_session_class: mock_session = AsyncMock() mock_session_class.return_value = mock_session @@ -250,8 +247,5 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): # Verify that _get_model_settings was called with disable_tracing=True mock_get_settings.assert_called_once_with( - agent=agent, - disable_tracing=True, - initial_settings=None, - overrides=None + agent=agent, disable_tracing=True, initial_settings=None, overrides=None ) diff --git a/tests/test_session_exceptions.py b/tests/test_session_exceptions.py new file mode 100644 index 000000000..c49c179aa --- /dev/null +++ b/tests/test_session_exceptions.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any +from unittest.mock import AsyncMock, Mock + +import pytest +import websockets.exceptions + +from agents.realtime.events import RealtimeError +from agents.realtime.model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener +from agents.realtime.model_events import ( + RealtimeModelErrorEvent, + RealtimeModelEvent, + RealtimeModelExceptionEvent, +) +from agents.realtime.session import RealtimeSession + + +class FakeRealtimeModel(RealtimeModel): + """Fake model for testing that forwards events to listeners.""" + + def __init__(self): + self._listeners: list[RealtimeModelListener] = [] + self._events_to_send: list[RealtimeModelEvent] = [] + self._is_connected = False + self._send_task: asyncio.Task[None] | None = None + + def set_next_events(self, events: list[RealtimeModelEvent]) -> None: + """Set events to be sent to listeners.""" + self._events_to_send = events.copy() + + async def connect(self, options: RealtimeModelConfig) -> None: + """Fake connection that starts sending events.""" + self._is_connected = True + self._send_task = asyncio.create_task(self._send_events()) + + async def _send_events(self) -> None: + """Send queued events to all listeners.""" + for event in self._events_to_send: + await asyncio.sleep(0.001) # Small delay to simulate async behavior + for listener in self._listeners: + await listener.on_event(event) + + def add_listener(self, listener: RealtimeModelListener) -> None: + """Add a listener.""" + self._listeners.append(listener) + + def remove_listener(self, listener: RealtimeModelListener) -> None: + """Remove a listener.""" + if listener in self._listeners: + self._listeners.remove(listener) + + async def close(self) -> None: + """Close the fake model.""" + self._is_connected = False + if self._send_task and not self._send_task.done(): + self._send_task.cancel() + try: + await self._send_task + except asyncio.CancelledError: + pass + + async def send_message( + self, message: Any, other_event_data: dict[str, Any] | None = None + ) -> None: + """Fake send message.""" + pass + + async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: + """Fake send audio.""" + pass + + async def send_event(self, event: Any) -> None: + """Fake send event.""" + pass + + async def send_tool_output(self, tool_call: Any, output: str, start_response: bool) -> None: + """Fake send tool output.""" + pass + + async def interrupt(self) -> None: + """Fake interrupt.""" + pass + + +@pytest.fixture +def fake_agent(): + """Create a fake agent for testing.""" + agent = Mock() + agent.get_all_tools = AsyncMock(return_value=[]) + return agent + + +@pytest.fixture +def fake_model(): + """Create a fake model for testing.""" + return FakeRealtimeModel() + + +class TestSessionExceptions: + """Test exception handling in RealtimeSession.""" + + @pytest.mark.asyncio + async def test_end_to_end_exception_propagation_and_cleanup( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test that exceptions are stored, trigger cleanup, and are raised in __aiter__.""" + # Create test exception + test_exception = ValueError("Test error") + exception_event = RealtimeModelExceptionEvent( + exception=test_exception, context="Test context" + ) + + # Set up session + session = RealtimeSession(fake_model, fake_agent, None) + + # Set events to send + fake_model.set_next_events([exception_event]) + + # Start session + async with session: + # Try to iterate and expect exception + with pytest.raises(ValueError, match="Test error"): + async for _event in session: + pass # Should never reach here + + # Verify cleanup occurred + assert session._closed is True + assert session._stored_exception == test_exception + assert fake_model._is_connected is False + assert len(fake_model._listeners) == 0 + + @pytest.mark.asyncio + async def test_websocket_connection_closure_type_distinction( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test different WebSocket closure types generate appropriate events.""" + # Test ConnectionClosed (should create exception event) + error_closure = websockets.exceptions.ConnectionClosed(None, None) + error_event = RealtimeModelExceptionEvent( + exception=error_closure, context="WebSocket connection closed unexpectedly" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([error_event]) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with session: + async for _event in session: + pass + + # Verify error closure triggered cleanup + assert session._closed is True + assert isinstance(session._stored_exception, websockets.exceptions.ConnectionClosed) + + @pytest.mark.asyncio + async def test_json_parsing_error_handling(self, fake_model: FakeRealtimeModel, fake_agent): + """Test JSON parsing errors are properly handled and contextualized.""" + # Create JSON decode error + json_error = json.JSONDecodeError("Invalid JSON", "bad json", 0) + json_exception_event = RealtimeModelExceptionEvent( + exception=json_error, context="Failed to parse WebSocket message as JSON" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([json_exception_event]) + + with pytest.raises(json.JSONDecodeError): + async with session: + async for _event in session: + pass + + # Verify context is preserved + assert session._stored_exception == json_error + assert session._closed is True + + @pytest.mark.asyncio + async def test_exception_context_preservation(self, fake_model: FakeRealtimeModel, fake_agent): + """Test that exception context information is preserved through the handling process.""" + test_contexts = [ + ("Failed to send audio", RuntimeError("Audio encoding failed")), + ("WebSocket error in message listener", ConnectionError("Network error")), + ("Failed to send event: response.create", OSError("Socket closed")), + ] + + for context, exception in test_contexts: + exception_event = RealtimeModelExceptionEvent(exception=exception, context=context) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([exception_event]) + + with pytest.raises(type(exception)): + async with session: + async for _event in session: + pass + + # Verify the exact exception is stored + assert session._stored_exception == exception + assert session._closed is True + + # Reset for next iteration + fake_model._is_connected = False + fake_model._listeners.clear() + + @pytest.mark.asyncio + async def test_multiple_exception_handling_behavior( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test behavior when multiple exceptions occur before consumption.""" + # Create multiple exceptions + first_exception = ValueError("First error") + second_exception = RuntimeError("Second error") + + first_event = RealtimeModelExceptionEvent( + exception=first_exception, context="First context" + ) + second_event = RealtimeModelExceptionEvent( + exception=second_exception, context="Second context" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([first_event, second_event]) + + # Start session and let events process + async with session: + # Give time for events to be processed + await asyncio.sleep(0.05) + + # The first exception should be stored (second should overwrite, but that's + # the current behavior). In practice, once an exception occurs, cleanup + # should prevent further processing + assert session._stored_exception is not None + assert session._closed is True + + @pytest.mark.asyncio + async def test_exception_during_guardrail_processing( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test that exceptions don't interfere with guardrail task cleanup.""" + # Create exception event + test_exception = RuntimeError("Processing error") + exception_event = RealtimeModelExceptionEvent( + exception=test_exception, context="Processing failed" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + + # Add some fake guardrail tasks + fake_task1 = Mock() + fake_task1.done.return_value = False + fake_task1.cancel = Mock() + + fake_task2 = Mock() + fake_task2.done.return_value = True + fake_task2.cancel = Mock() + + session._guardrail_tasks = {fake_task1, fake_task2} + + fake_model.set_next_events([exception_event]) + + with pytest.raises(RuntimeError, match="Processing error"): + async with session: + async for _event in session: + pass + + # Verify guardrail tasks were properly cleaned up + fake_task1.cancel.assert_called_once() + fake_task2.cancel.assert_not_called() # Already done + assert len(session._guardrail_tasks) == 0 + + @pytest.mark.asyncio + async def test_normal_events_still_work_before_exception( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test that normal events are processed before an exception occurs.""" + # Create normal event followed by exception + normal_event = RealtimeModelErrorEvent(error={"message": "Normal error"}) + exception_event = RealtimeModelExceptionEvent( + exception=ValueError("Fatal error"), context="Fatal context" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([normal_event, exception_event]) + + events_received = [] + + with pytest.raises(ValueError, match="Fatal error"): + async with session: + async for event in session: + events_received.append(event) + + # Should have received events before exception + assert len(events_received) >= 1 + # Look for the error event (might not be first due to history_updated + # being emitted initially) + error_events = [e for e in events_received if hasattr(e, "type") and e.type == "error"] + assert len(error_events) >= 1 + assert isinstance(error_events[0], RealtimeError)