From fac7da74d3ca54b87bd169cbd3125e799d500c7d Mon Sep 17 00:00:00 2001 From: Viraj Date: Thu, 21 Aug 2025 21:16:34 -0700 Subject: [PATCH] feat(run): add lifecycle interrupt + inject, cancel-aware tools, safer tracing --- src/agents/_run_impl.py | 99 +++++-- src/agents/models/openai_responses.py | 24 +- src/agents/result.py | 73 +++++- src/agents/run.py | 358 +++++++++++++++++++++++--- src/agents/stream_events.py | 21 +- tests/test_run_lifecycle.py | 320 +++++++++++++++++++++++ 6 files changed, 823 insertions(+), 72 deletions(-) create mode 100644 tests/test_run_lifecycle.py diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 6c417b308..56950bdda 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import dataclasses import inspect from collections.abc import Awaitable @@ -226,6 +227,29 @@ def get_model_tracing_impl( return ModelTracing.ENABLED_WITHOUT_DATA +# Helpers for cancellable tool execution + + +async def _await_cancellable(awaitable): + """Await an awaitable in its own task so CancelledError interrupts promptly.""" + task = asyncio.create_task(awaitable) + try: + return await task + except asyncio.CancelledError: + # propagate so run.py can handle terminal cancel + raise + + +def _maybe_call_cancel_hook(tool_obj) -> None: + """Best-effort: call a cancel/terminate hook on the tool if present.""" + for name in ("cancel", "terminate", "stop"): + cb = getattr(tool_obj, name, None) + if callable(cb): + with contextlib.suppress(Exception): + cb() + break + + class RunImpl: @classmethod async def execute_tools_and_side_effects( @@ -556,16 +580,24 @@ async def run_single_tool( if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: - _, _, result = await asyncio.gather( + # run start hooks first (don’t tie them to the cancellable task) + await asyncio.gather( hooks.on_tool_start(tool_context, agent, func_tool), ( 
agent.hooks.on_tool_start(tool_context, agent, func_tool) if agent.hooks else _coro.noop_coroutine() ), - func_tool.on_invoke_tool(tool_context, tool_call.arguments), ) + try: + result = await _await_cancellable( + func_tool.on_invoke_tool(tool_context, tool_call.arguments) + ) + except asyncio.CancelledError: + _maybe_call_cancel_hook(func_tool) + raise + await asyncio.gather( hooks.on_tool_end(tool_context, agent, func_tool, result), ( @@ -574,6 +606,7 @@ async def run_single_tool( else _coro.noop_coroutine() ), ) + except Exception as e: _error_tracing.attach_error_to_current_span( SpanError( @@ -644,7 +677,6 @@ async def execute_computer_actions( config: RunConfig, ) -> list[RunItem]: results: list[RunItem] = [] - # Need to run these serially, because each action can affect the computer state for action in actions: acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check: @@ -661,24 +693,28 @@ async def execute_computer_actions( if ack: acknowledged.append( ComputerCallOutputAcknowledgedSafetyCheck( - id=check.id, - code=check.code, - message=check.message, + id=check.id, code=check.code, message=check.message ) ) else: raise UserError("Computer tool safety check was not acknowledged") - results.append( - await ComputerAction.execute( - agent=agent, - action=action, - hooks=hooks, - context_wrapper=context_wrapper, - config=config, - acknowledged_safety_checks=acknowledged, + try: + item = await _await_cancellable( + ComputerAction.execute( + agent=agent, + action=action, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + acknowledged_safety_checks=acknowledged, + ) ) - ) + except asyncio.CancelledError: + _maybe_call_cancel_hook(action.computer_tool) + raise + + results.append(item) return results @@ -1052,16 +1088,23 @@ async def execute( else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call) ) - _, _, output = await 
asyncio.gather( + # start hooks first + await asyncio.gather( hooks.on_tool_start(context_wrapper, agent, action.computer_tool), ( agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) if agent.hooks else _coro.noop_coroutine() ), - output_func, ) - + # run the action (screenshot/etc) in a cancellable task + try: + output = await _await_cancellable(output_func) + except asyncio.CancelledError: + _maybe_call_cancel_hook(action.computer_tool) + raise + + # end hooks await asyncio.gather( hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), ( @@ -1169,10 +1212,20 @@ async def execute( data=call.tool_call, ) output = call.local_shell_tool.executor(request) - if inspect.isawaitable(output): - result = await output - else: - result = output + try: + if inspect.isawaitable(output): + result = await _await_cancellable(output) + else: + # If executor returns a sync result, just use it (can’t cancel mid-call) + result = output + except asyncio.CancelledError: + # Best-effort: if the executor or tool exposes a cancel/terminate, call it + _maybe_call_cancel_hook(call.local_shell_tool) + # If your executor returns a proc handle (common pattern), address it here if needed: + # with contextlib.suppress(Exception): + # proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0) + # proc.kill() + raise await asyncio.gather( hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), @@ -1185,7 +1238,7 @@ async def execute( return ToolCallOutputItem( agent=agent, - output=output, + output=result, raw_item={ "type": "local_shell_call_output", "id": call.tool_call.call_id, diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 4352c99c7..34aaf3829 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from collections.abc import AsyncIterator from dataclasses import dataclass 
@@ -172,15 +173,30 @@ async def stream_response( final_response: Response | None = None - async for chunk in stream: - if isinstance(chunk, ResponseCompletedEvent): - final_response = chunk.response - yield chunk + try: + async for chunk in stream: # ensure type checkers relax here + if isinstance(chunk, ResponseCompletedEvent): + final_response = chunk.response + yield chunk + except asyncio.CancelledError: + # Cooperative cancel: ensure the HTTP stream is closed, then propagate + try: + await stream.close() + except Exception: + pass + raise + finally: + # Always close the stream if the async iterator exits (normal or error) + try: + await stream.close() + except Exception: + pass if final_response and tracing.include_data(): span_response.span_data.response = final_response span_response.span_data.input = input + except Exception as e: span_response.set_error( SpanError( diff --git a/src/agents/result.py b/src/agents/result.py index 5cf0e74c8..1d259e36b 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -2,6 +2,7 @@ import abc import asyncio +import contextlib from collections.abc import AsyncIterator from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, cast @@ -143,6 +144,12 @@ class RunResultStreaming(RunResultBase): is_complete: bool = False """Whether the agent has finished running.""" + _emit_status_events: bool = False + """Whether to emit RunUpdatedStreamEvent status updates. + + Defaults to False for backward compatibility. 
+ """ + # Queues that the background run_loop writes to _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field( default_factory=asyncio.Queue, repr=False @@ -164,17 +171,45 @@ def last_agent(self) -> Agent[Any]: """ return self.current_agent - def cancel(self) -> None: - """Cancels the streaming run, stopping all background tasks and marking the run as - complete.""" - self._cleanup_tasks() # Cancel all running tasks - self.is_complete = True # Mark the run as complete to stop event streaming + def cancel(self, reason: str | None = None) -> None: + # 1) Signal cooperative cancel to the runner + active = getattr(self, "_active_run", None) + if active: + with contextlib.suppress(Exception): + active.cancel(reason) + # 2) Do NOT cancel the background task; let the loop unwind cooperatively + # task = getattr(self, "_run_impl_task", None) + # if task and not task.done(): + # with contextlib.suppress(Exception): + # task.cancel() + + # 4) Mark complete; flushing only when status events are disabled + self.is_complete = True + if not getattr(self, "_emit_status_events", False): + with contextlib.suppress(Exception): + while not self._event_queue.empty(): + self._event_queue.get_nowait() + self._event_queue.task_done() + with contextlib.suppress(Exception): + while not self._input_guardrail_queue.empty(): + self._input_guardrail_queue.get_nowait() + self._input_guardrail_queue.task_done() + + def inject(self, items: list[TResponseInputItem]) -> None: + """ + Inject new input items mid-run. They will be consumed at the start of the next step. 
+ """ + active = getattr(self, "_active_run", None) + if active is not None: + try: + active.inject(items) + except Exception: + pass - # Optionally, clear the event queue to prevent processing stale events - while not self._event_queue.empty(): - self._event_queue.get_nowait() - while not self._input_guardrail_queue.empty(): - self._input_guardrail_queue.get_nowait() + @property + def active_run(self): + """Access the underlying ActiveRun handle (may be None early in startup).""" + return getattr(self, "_active_run", None) async def stream_events(self) -> AsyncIterator[StreamEvent]: """Stream deltas for new items as they are generated. We're using the types from the @@ -243,21 +278,33 @@ def _check_errors(self): # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): run_impl_exc = self._run_impl_task.exception() - if run_impl_exc and isinstance(run_impl_exc, Exception): + if ( + run_impl_exc + and isinstance(run_impl_exc, Exception) + and not isinstance(run_impl_exc, asyncio.CancelledError) + ): if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None: run_impl_exc.run_data = self._create_error_details() self._stored_exception = run_impl_exc if self._input_guardrails_task and self._input_guardrails_task.done(): in_guard_exc = self._input_guardrails_task.exception() - if in_guard_exc and isinstance(in_guard_exc, Exception): + if ( + in_guard_exc + and isinstance(in_guard_exc, Exception) + and not isinstance(in_guard_exc, asyncio.CancelledError) + ): if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None: in_guard_exc.run_data = self._create_error_details() self._stored_exception = in_guard_exc if self._output_guardrails_task and self._output_guardrails_task.done(): out_guard_exc = self._output_guardrails_task.exception() - if out_guard_exc and isinstance(out_guard_exc, Exception): + if ( + out_guard_exc + and isinstance(out_guard_exc, Exception) + and not isinstance(out_guard_exc, 
asyncio.CancelledError) + ): if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None: out_guard_exc.run_data = self._create_error_details() self._stored_exception = out_guard_exc diff --git a/src/agents/run.py b/src/agents/run.py index e63d7751e..bd5510501 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -49,7 +49,7 @@ from .models.multi_provider import MultiProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext -from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent +from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunUpdatedStreamEvent from .tool import Tool from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData @@ -81,6 +81,55 @@ def get_default_agent_runner() -> AgentRunner: return DEFAULT_AGENT_RUNNER +# Cooperative cancellation + active handle + + +class _Cancellation: + def __init__(self): + self._ev = asyncio.Event() + self.reason: str | None = None + + def start(self, reason: str | None = None): + if not self._ev.is_set(): + self.reason = reason + self._ev.set() + + def is_cancelling(self) -> bool: + return self._ev.is_set() + + def raise_if_cancelled(self): + if self.is_cancelling(): + raise asyncio.CancelledError(self.reason or "Cancelled") + + +class _ActiveRun: + """ + Lightweight handle exposed to results so callers can cancel or inject input. + NOTE: the `RunResultStreaming` will keep a reference to this. 
+ """ + + def __init__( + self, + cancel: _Cancellation, + inbox: list[TResponseInputItem], + state_cb: Callable[[], dict[str, Any]], + ): + self._cancel = cancel + self._inbox = inbox + self._state_cb = state_cb + + def cancel(self, reason: str | None = None) -> None: + self._cancel.start(reason) + + def inject(self, items: list[TResponseInputItem]) -> None: + # Append external inputs; consumed at the start of the next step + self._inbox.extend(items) + + def state(self) -> dict[str, Any]: + # optional: expose minimal state for debugging + return self._state_cb() + + @dataclass class ModelInputData: """Container for the data that will be sent to the model.""" @@ -356,6 +405,21 @@ class AgentRunner: It should not be used directly or subclassed. """ + @staticmethod + def _safe_finish(obj, *, reset_current: bool = True) -> None: + """ + Finish a span/trace safely even if called from a different task context. + Tries reset_current=True first; falls back to reset_current=False if needed. + """ + try: + obj.finish(reset_current=reset_current) + except Exception: + try: + obj.finish(reset_current=False) + except Exception: + # Last-resort: suppress exceptions since we are already tearing down. 
+ pass + async def run( self, starting_agent: Agent[TContext], @@ -376,6 +440,18 @@ async def run( # Prepare input with session if enabled prepared_input = await self._prepare_input_with_session(input, session) + # Cancellation + inbox + handle for non-streamed runs + cancel_token = _Cancellation() + inbox: list[TResponseInputItem] = [] + + def _state_cb() -> dict[str, Any]: + return { + "current_turn": 0, # we'll update this below + "inbox_len": len(inbox), + } + + active_run = _ActiveRun(cancel_token, inbox, _state_cb) + tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -386,6 +462,11 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 + + def _update_state_turn(n: int): + _state_cb_dict = active_run.state() + _state_cb_dict["current_turn"] = n # optional, purely for debugging + original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -393,6 +474,9 @@ async def run( context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context=context, # type: ignore ) + # Stash inbox + cancel token on context wrapper for internal access + cast(Any, context_wrapper)._inbox = inbox + cast(Any, context_wrapper)._cancel_token = cancel_token input_guardrail_results: list[InputGuardrailResult] = [] @@ -404,6 +488,12 @@ async def run( while True: all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) + # Cooperative cancel at loop top + if cancel_token.is_cancelling(): + raise asyncio.CancelledError(cancel_token.reason or "Cancelled") + + _update_state_turn(current_turn) + # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. 
if current_span is None: @@ -500,11 +590,12 @@ async def run( # Save the conversation to session if enabled await self._save_result_to_session(session, input, result) + cast(Any, result).active_run = active_run # expose handle on non-streamed return result elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) - current_span.finish(reset_current=True) + AgentRunner._safe_finish(current_span, reset_current=True) current_span = None should_run_agent_start_hooks = True elif isinstance(turn_result.next_step, NextStepRunAgain): @@ -513,6 +604,24 @@ async def run( raise AgentsException( f"Unknown next step type: {type(turn_result.next_step)}" ) + + except asyncio.CancelledError as _c: + # Produce a terminal cancelled result; mirror the RunResult shape + result = RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=None, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + context_wrapper=context_wrapper, + ) + # Save to session if enabled + cast(Any, result).active_run = active_run + await self._save_result_to_session(session, input, result) + return result + except AgentsException as exc: exc.run_data = RunErrorDetails( input=original_input, @@ -526,7 +635,7 @@ async def run( raise finally: if current_span: - current_span.finish(reset_current=True) + AgentRunner._safe_finish(current_span, reset_current=True) def run_sync( self, @@ -608,6 +717,27 @@ def run_streamed( context_wrapper=context_wrapper, ) + # Cancellation + inbox + handle for this streamed run + cancel_token = _Cancellation() + inbox: list[TResponseInputItem] = [] + + # A tiny state closure for debugging/inspection + def _state_cb() -> dict[str, Any]: + current = getattr(streamed_result, "current_agent", None) + return { + "current_agent": current.name if current else None, + "current_turn": 
streamed_result.current_turn, + "is_complete": streamed_result.is_complete, + "inbox_len": len(inbox), + } + + active_run = _ActiveRun(cancel_token, inbox, _state_cb) + + # Stash these on the streamed_result; you'll expose helpers in result.py + cast(Any, streamed_result)._active_run = active_run + cast(Any, streamed_result)._cancel_token = cancel_token + cast(Any, streamed_result)._inbox = inbox + # Kick off the actual agent loop in the background and return the streamed result object. streamed_result._run_impl_task = asyncio.create_task( self._start_streaming( @@ -620,8 +750,11 @@ def run_streamed( run_config=run_config, previous_response_id=previous_response_id, session=session, + _cancel_token=cancel_token, + _inbox=inbox, ) ) + return streamed_result @classmethod @@ -720,6 +853,8 @@ async def _start_streaming( run_config: RunConfig, previous_response_id: str | None, session: Session | None, + _cancel_token: _Cancellation, + _inbox: list[TResponseInputItem], ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) @@ -732,6 +867,10 @@ async def _start_streaming( streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) + # Track whether we've already closed span/trace in a special path (GeneratorExit) + span_finished = False + trace_finished = False + try: # Prepare input with session if enabled prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) @@ -740,6 +879,16 @@ async def _start_streaming( streamed_result.input = prepared_input while True: + # Cooperative cancel at loop top + if _cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="cancelled", reason=_cancel_token.reason) + ) + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + if streamed_result.is_complete: break @@ -765,6 +914,7 @@ async def 
_start_streaming( current_span.start(mark_as_current=True) tool_names = [t.name for t in all_tools] current_span.span_data.tools = tool_names + current_turn += 1 streamed_result.current_turn = current_turn @@ -776,7 +926,15 @@ async def _start_streaming( data={"max_turns": max_turns}, ), ) - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="failed", reason=f"Max turns exceeded ({max_turns})" + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) break if current_turn == 1: @@ -813,12 +971,18 @@ async def _start_streaming( if isinstance(turn_result.next_step, NextStepHandoff): current_agent = turn_result.next_step.new_agent - current_span.finish(reset_current=True) + if current_span: + AgentRunner._safe_finish(current_span, reset_current=True) + span_finished = True # this span is closed here current_span = None should_run_agent_start_hooks = True streamed_result._event_queue.put_nowait( AgentUpdatedStreamEvent(new_agent=current_agent) ) + + # After handoff, allow a new span to start on next loop + span_finished = False + elif isinstance(turn_result.next_step, NextStepFinalOutput): streamed_result._output_guardrails_task = asyncio.create_task( cls._run_output_guardrails( @@ -841,7 +1005,6 @@ async def _start_streaming( streamed_result.is_complete = True # Save the conversation to session if enabled - # Create a temporary RunResult for session saving temp_result = RunResult( input=streamed_result.input, new_items=streamed_result.new_items, @@ -856,12 +1019,40 @@ async def _start_streaming( session, starting_input, temp_result ) + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="completed") + ) + 
streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + elif isinstance(turn_result.next_step, NextStepRunAgain): + # No-op; continue loop for another turn pass + except AgentsException as exc: - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + # If a cancel was requested, normalize any exception as "cancelled" + if _cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="cancelled", + reason=getattr(_cancel_token, "reason", None), + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise + + # existing "failed" path + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="failed", reason=exc.__class__.__name__) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) exc.run_data = RunErrorDetails( input=streamed_result.input, new_items=streamed_result.new_items, @@ -872,25 +1063,70 @@ async def _start_streaming( output_guardrail_results=streamed_result.output_guardrail_results, ) raise + + except asyncio.CancelledError: + # Cooperative cancellation: treat as a normal terminal state + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="cancelled", + reason=getattr(_cancel_token, "reason", None), + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + except Exception as e: + # If a cancel was requested, normalize any exception as "cancelled" + if _cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + 
streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="cancelled", + reason=getattr(_cancel_token, "reason", None), + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise + if current_span: _error_tracing.attach_error_to_span( current_span, - SpanError( - message="Error in agent run", - data={"error": str(e)}, - ), + SpanError(message="Error in agent run", data={"error": str(e)}), ) - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="failed", reason=e.__class__.__name__) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) raise streamed_result.is_complete = True + + except GeneratorExit: + # The coroutine is being garbage-collected/closed; avoid cross-context resets. + try: + if current_span and not span_finished: + AgentRunner._safe_finish(current_span, reset_current=False) + span_finished = True + if streamed_result.trace and not trace_finished: + AgentRunner._safe_finish(streamed_result.trace, reset_current=False) + trace_finished = True + finally: + # Respect generator close semantics. 
+ raise + finally: - if current_span: - current_span.finish(reset_current=True) - if streamed_result.trace: - streamed_result.trace.finish(reset_current=True) + if current_span and not span_finished: + AgentRunner._safe_finish(current_span, reset_current=True) + if streamed_result.trace and not trace_finished: + AgentRunner._safe_finish(streamed_result.trace, reset_current=True) @classmethod async def _run_single_turn_streamed( @@ -931,10 +1167,18 @@ async def _run_single_turn_streamed( model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) final_response: ModelResponse | None = None + injected_during_turn = False input = ItemHelpers.input_to_new_input_list(streamed_result.input) input.extend([item.to_input_item() for item in streamed_result.new_items]) + # Consume any externally injected items before planning/model call + # Externally injected items live in streamed_result._inbox (a list of input items) + injected = getattr(streamed_result, "_inbox", None) + if injected: + input.extend(injected) + injected.clear() + # THIS IS THE RESOLVED CONFLICT BLOCK filtered = await cls._maybe_filter_model_input( agent=agent, @@ -964,6 +1208,12 @@ async def _run_single_turn_streamed( previous_response_id=previous_response_id, prompt=prompt_config, ): + # Cooperative cancel during streaming + cancel_token = getattr(streamed_result, "_cancel_token", None) + if cancel_token and cancel_token.is_cancelling(): + # Stop iterating; the model adapter should also close its stream cooperatively. + break + if isinstance(event, ResponseCompletedEvent): usage = ( Usage( @@ -986,6 +1236,31 @@ async def _run_single_turn_streamed( streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) + # Break early if new items were injected during this turn. 
+ if injected and len(injected) > 0: + injected_during_turn = True + break + + if injected_during_turn and final_response is None: + return SingleStepResult( + original_input=streamed_result.input, + model_response=ModelResponse(output=[], usage=Usage(), response_id=None), + pre_step_items=streamed_result.new_items, + new_step_items=[], + next_step=NextStepRunAgain(), + ) + + # If cancelled during streaming, terminate cleanly + cancel_token = getattr(streamed_result, "_cancel_token", None) + if cancel_token and cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="cancelled", reason=cancel_token.reason) + ) + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise asyncio.CancelledError(cancel_token.reason or "Cancelled") + # Call hook just after the model response is finalized. if agent.hooks and final_response is not None: await agent.hooks.on_llm_end(context_wrapper, agent, final_response) @@ -1044,6 +1319,13 @@ async def _run_single_turn( input = ItemHelpers.input_to_new_input_list(original_input) input.extend([generated_item.to_input_item() for generated_item in generated_items]) + # Consume injected items (non-streamed runs) + # We stashed the inbox on the context wrapper to avoid changing all signatures. 
+ inbox: list[TResponseInputItem] | None = getattr(context_wrapper, "_inbox", None) + if inbox: + input.extend(inbox) + inbox.clear() + new_response = await cls._get_new_response( agent, system_prompt, @@ -1251,6 +1533,11 @@ async def _get_new_response( previous_response_id: str | None, prompt_config: ResponsePromptParam | None, ) -> ModelResponse: + # Cooperative cancel before the model call (non-streamed) --- + cancel_token: _Cancellation | None = getattr(context_wrapper, "_cancel_token", None) + if cancel_token and cancel_token.is_cancelling(): + raise asyncio.CancelledError(cancel_token.reason or "Cancelled") + # Allow user to modify model input right before the call, if configured filtered = await cls._maybe_filter_model_input( agent=agent, @@ -1272,19 +1559,28 @@ async def _get_new_response( filtered.input, # Use filtered input ) - new_response = await model.get_response( - system_instructions=filtered.instructions, - input=filtered.input, - model_settings=model_settings, - tools=all_tools, - output_schema=output_schema, - handoffs=handoffs, - tracing=get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - prompt=prompt_config, - ) + async def _call_model() -> ModelResponse: + return await model.get_response( + system_instructions=filtered.instructions, + input=filtered.input, + model_settings=model_settings, + tools=all_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + previous_response_id=previous_response_id, + prompt=prompt_config, + ) + + task = asyncio.create_task(_call_model()) + try: + new_response = await task + except asyncio.CancelledError: + # propagate; caller handles terminal state + raise + # If the agent has hooks, we need to call them after the LLM call if agent.hooks: await agent.hooks.on_llm_end(context_wrapper, agent, new_response) 
diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index a271e8acd..e4a90c603 100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -56,6 +56,25 @@ class AgentUpdatedStreamEvent: type: Literal["agent_updated_stream_event"] = "agent_updated_stream_event" + # Terminal / status update event for the overall run -StreamEvent: TypeAlias = Union[RawResponsesStreamEvent, RunItemStreamEvent, AgentUpdatedStreamEvent] + +@dataclass +class RunUpdatedStreamEvent: + """High-level run status update (emitted on completion, failure, or cancellation).""" + + status: Literal["running", "completed", "failed", "cancelled"] = "running" + """Current run status.""" + reason: str | None = None + """Optional human-readable reason (e.g., cancellation reason).""" + type: Literal["run.updated"] = "run.updated" + """Event type identifier.""" + + +StreamEvent: TypeAlias = Union[ + RawResponsesStreamEvent, + RunItemStreamEvent, + AgentUpdatedStreamEvent, + RunUpdatedStreamEvent, +] """A streaming event from an agent.""" diff --git a/tests/test_run_lifecycle.py b/tests/test_run_lifecycle.py new file mode 100644 index 000000000..14fbfb687 --- /dev/null +++ b/tests/test_run_lifecycle.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator +from typing import Any, cast + +import pytest +from openai.types.responses import ResponseStreamEvent + +from agents.agent import Agent +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import TResponseInputItem +from agents.model_settings import ModelSettings +from agents.models.interface import Model, ModelTracing +from agents.run import RunConfig, Runner +from agents.stream_events import RunUpdatedStreamEvent +from agents.tool import Tool + +# Reuse the repo’s helper to build a FunctionTool correctly +from tests.test_responses import get_function_tool # <-- existing test helper + + 
+# Upper bound (seconds) on waiting for a cancelled run to finish, so that a regression which
+# makes cancellation hang fails the test with TimeoutError instead of blocking the session.
+_CANCEL_TIMEOUT_S = 5.0
+
+
+async def _drain(result: Any) -> None:
+    """Consume and discard every event yielded by ``result.stream_events()``."""
+    async for _ in result.stream_events():
+        pass
+
+
+class MinimalAgent:
+    """Just enough surface for Runner."""
+
+    def __init__(self, model: Model, name: str = "test-agent"):
+        self.name = name
+        self.model = model
+        self.model_settings = ModelSettings()
+        self.output_type = None
+        self.hooks = None
+        self.handoffs: list[Handoff] = []
+        self.reset_tool_choice = False
+        self.input_guardrails: list[Any] = []
+        self.output_guardrails: list[Any] = []
+
+    async def get_system_prompt(self, _):
+        return None
+
+    async def get_prompt(self, _):
+        return None
+
+    async def get_all_tools(self, _):
+        return []
+
+
+class FakeModelNeverCompletes(Model):
+    """Never completes; yields generic events forever so we can cancel mid-stream."""
+
+    async def get_response(self, *a, **k):
+        raise NotImplementedError
+
+    async def stream_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None,
+        prompt=None,
+    ) -> AsyncIterator[ResponseStreamEvent]:
+        while True:
+            await asyncio.sleep(0.02)
+            yield cast(ResponseStreamEvent, object())
+
+
+@pytest.mark.anyio
+async def test_cancel_streamed_run_emits_cancelled_status():
+    """When status events are enabled, cancel should emit run.updated(cancelled)."""
+    agent = MinimalAgent(model=FakeModelNeverCompletes())
+    result = Runner.run_streamed(
+        cast(Agent[Any], agent),
+        input="hello world",
+        run_config=RunConfig(model=agent.model),
+        max_turns=10,
+    )
+    # Opt-in to status events for this test
+    result._emit_status_events = True
+
+    seen_status: str | None = None
+
+    async def consume():
+        nonlocal seen_status
+        async for ev in result.stream_events():
+            if isinstance(ev, RunUpdatedStreamEvent):
+                seen_status = ev.status
+
+    consumer = asyncio.create_task(consume())
+    await asyncio.sleep(0.08)
+    result.cancel("user-requested")
+    # Bounded wait: if cancel() never ends the stream, fail instead of hanging.
+    await asyncio.wait_for(consumer, timeout=_CANCEL_TIMEOUT_S)
+
+    assert result.is_complete is True
+    assert seen_status == "cancelled"
+
+
+@pytest.mark.anyio
+async def test_default_flag_off_emits_no_status_event():
+    """By default, no run.updated events should be emitted (back-compat)."""
+    agent = MinimalAgent(model=FakeModelNeverCompletes())
+    result = Runner.run_streamed(
+        cast(Agent[Any], agent),
+        input="x",
+        run_config=RunConfig(model=agent.model),
+    )
+    # DO NOT set result._emit_status_events here
+    statuses: list[str] = []
+
+    async def consume():
+        async for ev in result.stream_events():
+            if isinstance(ev, RunUpdatedStreamEvent):
+                statuses.append(ev.status)
+
+    task = asyncio.create_task(consume())
+    await asyncio.sleep(0.05)
+    result.cancel("user")
+    # Bounded wait: if cancel() never ends the stream, fail instead of hanging.
+    await asyncio.wait_for(task, timeout=_CANCEL_TIMEOUT_S)
+
+    assert statuses == []  # no run.updated by default
+
+
+@pytest.mark.anyio
+async def test_midstream_cancel_emits_cancelled_status_when_enabled():
+    """Cancel while model is streaming yields cancelled when flag is on."""
+    agent = MinimalAgent(model=FakeModelNeverCompletes())
+    result = Runner.run_streamed(
+        cast(Agent[Any], agent),
+        input="x",
+        run_config=RunConfig(model=agent.model),
+    )
+    result._emit_status_events = True
+    statuses: list[str] = []
+
+    async def consume():
+        async for ev in result.stream_events():
+            if isinstance(ev, RunUpdatedStreamEvent):
+                statuses.append(ev.status)
+
+    task = asyncio.create_task(consume())
+    await asyncio.sleep(0.06)
+    result.cancel("user")
+    # Bounded wait: if cancel() never ends the stream, fail instead of hanging.
+    await asyncio.wait_for(task, timeout=_CANCEL_TIMEOUT_S)
+
+    assert "cancelled" in statuses
+
+
+class FakeModelSlowGet(Model):
+    """Non-streaming fake: ``get_response`` sleeps 1s (then returns None) so we can cancel."""
+
+    async def get_response(self, *a, **k):
+        # simulate long compute so we can cancel
+        await asyncio.sleep(1.0)
+
+    async def stream_response(self, *a, **k):
+        raise NotImplementedError
+
+
+@pytest.mark.anyio
+async def test_non_streamed_cancel_propagates_cancelled_error_or_returns_terminal_result():
+    """Runner.run cancellation should terminate cleanly.
+
+    We accept either a CancelledError or a terminal RunResult.
+    """
+    agent = MinimalAgent(model=FakeModelSlowGet())
+
+    async def run_it():
+        return await Runner.run(
+            cast(Agent[Any], agent),
+            input="y",
+            run_config=RunConfig(model=agent.model),
+        )
+
+    task = asyncio.create_task(run_it())
+    await asyncio.sleep(0.05)
+    task.cancel()
+
+    try:
+        result = await task
+    except asyncio.CancelledError:
+        # Current contract may propagate cancel; this is acceptable.
+        return
+
+    # If your contract returns a terminal result on cancel, assert it here.
+    assert getattr(result, "final_output", None) is None
+
+
+@pytest.mark.anyio
+async def test_inject_is_consumed_on_next_turn():
+    """
+    Injected items should be included in a subsequent model turn input.
+    We capture the inputs passed into FakeModel each turn and assert presence.
+    """
+    INJECT_TOKEN: TResponseInputItem = {
+        "role": "user",
+        "content": "INJECTED",
+    }  # match message-style items
+
+    class FakeModelCapture(Model):
+        """Streams forever, snapshotting ``input`` into ``self.inputs`` before each event."""
+
+        def __init__(self):
+            self.inputs: list[list[Any]] = []
+
+        async def get_response(self, *a, **k):  # non-stream path not used
+            raise NotImplementedError
+
+        async def stream_response(
+            self,
+            system_instructions,
+            input,
+            model_settings,
+            tools,
+            output_schema,
+            handoffs,
+            tracing,
+            previous_response_id,
+            prompt=None,
+        ) -> AsyncIterator[ResponseStreamEvent]:
+            # Keep streaming so we never hit the "no final response" error.
+            while True:
+                # Snapshot the current input (a copy) before each emitted event.
+                self.inputs.append(list(input))
+                # Emit one generic event, then pause briefly before the next snapshot.
+                yield cast(ResponseStreamEvent, object())
+                await asyncio.sleep(0.01)
+
+    model = FakeModelCapture()
+    agent = MinimalAgent(model=model)
+
+    result = Runner.run_streamed(
+        starting_agent=cast(Agent[Any], agent),
+        input="hello",
+        run_config=RunConfig(model=agent.model),
+        max_turns=6,
+    )
+
+    async def drive_and_inject():
+        # Let at least one turn record baseline input
+        await asyncio.sleep(0.05)
+        # Inject so a future turn sees it
+        result.inject([INJECT_TOKEN])
+        # Give time for a couple more turns to run
+        await asyncio.sleep(0.12)
+        result.cancel("done")
+
+    consumer = asyncio.create_task(drive_and_inject())
+    await asyncio.wait_for(_drain(result), timeout=_CANCEL_TIMEOUT_S)
+    await consumer
+
+    # We should have recorded ≥2 input snapshots
+    assert len(model.inputs) >= 2
+
+    # Assert the injected message appears in ANY snapshot after the first one
+    flattened_after_injection = [item for turn in model.inputs[1:] for item in turn]
+    assert any(
+        isinstance(item, dict) and item.get("role") == "user" and item.get("content") == "INJECTED"
+        for item in flattened_after_injection
+    ), f"Injected item not present after injection; captured={model.inputs}"
+
+
+class FakeModelTriggersTool(FakeModelNeverCompletes):
+    """
+    Emits continuous events so we can cancel while a function tool is (hypothetically) running.
+    Note: This is a timing smoke test. For a full tool-call path test, emit tool-call outputs.
+    """
+
+
+class AgentWithTool(MinimalAgent):
+    """MinimalAgent variant whose ``get_all_tools`` returns the single tool it was given."""
+
+    def __init__(self, model: Model, tool: Tool):
+        super().__init__(model)
+        self._tool = tool
+
+    async def get_all_tools(self, _):
+        return [self._tool]
+
+
+@pytest.mark.anyio
+async def test_function_tool_cancels_promptly():
+    """Cancelling a streamed run whose agent exposes a function tool finishes promptly."""
+    # Build the tool using the repo helper (it doesn't take a handler argument)
+    tool = get_function_tool("long", "done")
+
+    agent = AgentWithTool(FakeModelTriggersTool(), tool)
+
+    result = Runner.run_streamed(
+        cast(Agent[Any], agent),
+        input="trigger tool",
+        run_config=RunConfig(model=agent.model),
+    )
+    start = time.perf_counter()
+    await asyncio.sleep(0.05)  # let some activity happen
+    result.cancel("user")
+
+    # Drain stream; wait_for turns a hang into TimeoutError rather than blocking forever
+    await asyncio.wait_for(_drain(result), timeout=_CANCEL_TIMEOUT_S)
+
+    elapsed = time.perf_counter() - start
+    # Expect prompt cancellation (well under 1s)
+    assert elapsed < 0.4