diff --git a/src/agents/realtime/events.py b/src/agents/realtime/events.py
index e1b3cfea3..24444b66e 100644
--- a/src/agents/realtime/events.py
+++ b/src/agents/realtime/events.py
@@ -197,6 +197,7 @@ class RealtimeGuardrailTripped:
 
     type: Literal["guardrail_tripped"] = "guardrail_tripped"
 
+
 RealtimeSessionEvent: TypeAlias = Union[
     RealtimeAgentStartEvent,
     RealtimeAgentEndEvent,
diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py
index 0ca3bf7af..8c909fd07 100644
--- a/src/agents/realtime/session.py
+++ b/src/agents/realtime/session.py
@@ -93,6 +93,8 @@ def __init__(
             "debounce_text_length", 100
         )
 
+        self._guardrail_tasks: set[asyncio.Task[Any]] = set()
+
     async def __aenter__(self) -> RealtimeSession:
         """Start the session by connecting to the model. After this, you will be able to stream
         events from the model and send messages and audio to the model.
@@ -136,6 +138,7 @@ async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
     async def close(self) -> None:
         """Close the session."""
         self._closed = True
+        self._cleanup_guardrail_tasks()
         self._model.remove_listener(self)
         await self._model.close()
 
@@ -185,7 +188,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
 
             if current_length >= next_run_threshold:
                 self._item_guardrail_run_counts[item_id] += 1
-                await self._run_output_guardrails(self._item_transcripts[item_id])
+                self._enqueue_guardrail_task(self._item_transcripts[item_id])
         elif event.type == "item_updated":
             is_new = not any(item.item_id == event.item.item_id for item in self._history)
             self._history = self._get_new_history(self._history, event.item)
@@ -366,3 +369,47 @@ async def _run_output_guardrails(self, text: str) -> bool:
             return True
 
         return False
+
+    def _enqueue_guardrail_task(self, text: str) -> None:
+        """Run the output guardrails for the given text in a background task.
+
+        Scheduling a task instead of awaiting keeps on_event from blocking while the
+        guardrails run.
+        """
+        task = asyncio.create_task(self._run_output_guardrails(text))
+        # Hold a strong reference: the event loop only keeps weak references to tasks.
+        self._guardrail_tasks.add(task)
+        # Untrack the task and surface any failure once it finishes.
+        task.add_done_callback(self._on_guardrail_task_done)
+
+    def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
+        """Untrack a finished guardrail task and report its failure, if any, as an event."""
+        self._guardrail_tasks.discard(task)
+
+        # task.exception() raises CancelledError for a cancelled task, so check that first.
+        if task.cancelled():
+            return
+
+        exception = task.exception()
+        if exception is None:
+            return
+
+        # Emit an error event instead of raising inside the done callback. The reporting
+        # task is tracked too, so it cannot be garbage collected before it has run.
+        error_task = asyncio.create_task(
+            self._put_event(
+                RealtimeError(
+                    info=self._event_info,
+                    error={"message": f"Guardrail task failed: {exception}"},
+                )
+            )
+        )
+        self._guardrail_tasks.add(error_task)
+        error_task.add_done_callback(self._guardrail_tasks.discard)
+
+    def _cleanup_guardrail_tasks(self) -> None:
+        """Cancel guardrail tasks that are still running and stop tracking them."""
+        for task in self._guardrail_tasks:
+            if not task.done():
+                task.cancel()
+        self._guardrail_tasks.clear()
diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py
index 6003f9443..2cd71b023 100644
--- a/tests/realtime/test_session.py
+++ b/tests/realtime/test_session.py
@@ -971,6 +971,12 @@ async def test_mixed_tool_types_filtering(self, mock_model, mock_agent):
 class TestGuardrailFunctionality:
     """Test suite for output guardrail functionality in RealtimeSession"""
 
+    async def _wait_for_guardrail_tasks(self, session):
+        """Wait for all pending guardrail tasks to complete."""
+        import asyncio
+        if session._guardrail_tasks:
+            await asyncio.gather(*session._guardrail_tasks, return_exceptions=True)
+
     @pytest.fixture
     def triggered_guardrail(self):
         """Creates a guardrail that always triggers"""
@@ -1010,6 +1016,9 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(
 
         await session.on_event(transcript_event)
 
+        # Wait for async guardrail tasks to complete
+        await self._wait_for_guardrail_tasks(session)
+
         # Should have triggered guardrail and interrupted
         assert session._interrupted_by_guardrail is True
         assert mock_model.interrupts_called == 1
@@ -1047,6 +1056,9 @@ async def test_transcript_delta_multiple_thresholds_same_item(
             item_id="item_1", delta="67890", response_id="resp_1"
         ))
 
+        # Wait for async guardrail tasks to complete
+        await self._wait_for_guardrail_tasks(session)
+
         # Should only trigger once due to interrupted_by_guardrail flag
         assert mock_model.interrupts_called == 1
         assert len(mock_model.sent_messages) == 1
@@ -1105,6 +1117,9 @@ async def test_turn_ended_clears_guardrail_state(
             item_id="item_1", delta="trigger", response_id="resp_1"
         ))
 
+        # Wait for async guardrail tasks to complete
+        await self._wait_for_guardrail_tasks(session)
+
         assert session._interrupted_by_guardrail is True
         assert len(session._item_transcripts) == 1
 
@@ -1143,6 +1158,9 @@ def guardrail_func(context, agent, output):
             item_id="item_1", delta="trigger", response_id="resp_1"
         ))
 
+        # Wait for async guardrail tasks to complete
+        await self._wait_for_guardrail_tasks(session)
+
         # Should have interrupted and sent message with both guardrail names
         assert mock_model.interrupts_called == 1
         assert len(mock_model.sent_messages) == 1