Skip to content

Commit fac7da7

Browse files
committed
feat(run): add lifecycle interrupt + inject, cancel-aware tools, safer tracing
1 parent 71be678 commit fac7da7

File tree

6 files changed

+823
-72
lines changed

6 files changed

+823
-72
lines changed

src/agents/_run_impl.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import dataclasses
56
import inspect
67
from collections.abc import Awaitable
@@ -226,6 +227,29 @@ def get_model_tracing_impl(
226227
return ModelTracing.ENABLED_WITHOUT_DATA
227228

228229

230+
# Helpers for cancellable tool execution
231+
232+
233+
async def _await_cancellable(awaitable):
234+
"""Await an awaitable in its own task so CancelledError interrupts promptly."""
235+
task = asyncio.create_task(awaitable)
236+
try:
237+
return await task
238+
except asyncio.CancelledError:
239+
# propagate so run.py can handle terminal cancel
240+
raise
241+
242+
243+
def _maybe_call_cancel_hook(tool_obj) -> None:
244+
"""Best-effort: call a cancel/terminate hook on the tool if present."""
245+
for name in ("cancel", "terminate", "stop"):
246+
cb = getattr(tool_obj, name, None)
247+
if callable(cb):
248+
with contextlib.suppress(Exception):
249+
cb()
250+
break
251+
252+
229253
class RunImpl:
230254
@classmethod
231255
async def execute_tools_and_side_effects(
@@ -556,16 +580,24 @@ async def run_single_tool(
556580
if config.trace_include_sensitive_data:
557581
span_fn.span_data.input = tool_call.arguments
558582
try:
559-
_, _, result = await asyncio.gather(
583+
# run start hooks first (don’t tie them to the cancellable task)
584+
await asyncio.gather(
560585
hooks.on_tool_start(tool_context, agent, func_tool),
561586
(
562587
agent.hooks.on_tool_start(tool_context, agent, func_tool)
563588
if agent.hooks
564589
else _coro.noop_coroutine()
565590
),
566-
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
567591
)
568592

593+
try:
594+
result = await _await_cancellable(
595+
func_tool.on_invoke_tool(tool_context, tool_call.arguments)
596+
)
597+
except asyncio.CancelledError:
598+
_maybe_call_cancel_hook(func_tool)
599+
raise
600+
569601
await asyncio.gather(
570602
hooks.on_tool_end(tool_context, agent, func_tool, result),
571603
(
@@ -574,6 +606,7 @@ async def run_single_tool(
574606
else _coro.noop_coroutine()
575607
),
576608
)
609+
577610
except Exception as e:
578611
_error_tracing.attach_error_to_current_span(
579612
SpanError(
@@ -644,7 +677,6 @@ async def execute_computer_actions(
644677
config: RunConfig,
645678
) -> list[RunItem]:
646679
results: list[RunItem] = []
647-
# Need to run these serially, because each action can affect the computer state
648680
for action in actions:
649681
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
650682
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
@@ -661,24 +693,28 @@ async def execute_computer_actions(
661693
if ack:
662694
acknowledged.append(
663695
ComputerCallOutputAcknowledgedSafetyCheck(
664-
id=check.id,
665-
code=check.code,
666-
message=check.message,
696+
id=check.id, code=check.code, message=check.message
667697
)
668698
)
669699
else:
670700
raise UserError("Computer tool safety check was not acknowledged")
671701

672-
results.append(
673-
await ComputerAction.execute(
674-
agent=agent,
675-
action=action,
676-
hooks=hooks,
677-
context_wrapper=context_wrapper,
678-
config=config,
679-
acknowledged_safety_checks=acknowledged,
702+
try:
703+
item = await _await_cancellable(
704+
ComputerAction.execute(
705+
agent=agent,
706+
action=action,
707+
hooks=hooks,
708+
context_wrapper=context_wrapper,
709+
config=config,
710+
acknowledged_safety_checks=acknowledged,
711+
)
680712
)
681-
)
713+
except asyncio.CancelledError:
714+
_maybe_call_cancel_hook(action.computer_tool)
715+
raise
716+
717+
results.append(item)
682718

683719
return results
684720

@@ -1052,16 +1088,23 @@ async def execute(
10521088
else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
10531089
)
10541090

1055-
_, _, output = await asyncio.gather(
1091+
# start hooks first
1092+
await asyncio.gather(
10561093
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
10571094
(
10581095
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
10591096
if agent.hooks
10601097
else _coro.noop_coroutine()
10611098
),
1062-
output_func,
10631099
)
1064-
1100+
# run the action (screenshot/etc) in a cancellable task
1101+
try:
1102+
output = await _await_cancellable(output_func)
1103+
except asyncio.CancelledError:
1104+
_maybe_call_cancel_hook(action.computer_tool)
1105+
raise
1106+
1107+
# end hooks
10651108
await asyncio.gather(
10661109
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
10671110
(
@@ -1169,10 +1212,20 @@ async def execute(
11691212
data=call.tool_call,
11701213
)
11711214
output = call.local_shell_tool.executor(request)
1172-
if inspect.isawaitable(output):
1173-
result = await output
1174-
else:
1175-
result = output
1215+
try:
1216+
if inspect.isawaitable(output):
1217+
result = await _await_cancellable(output)
1218+
else:
1219+
# If executor returns a sync result, just use it (can’t cancel mid-call)
1220+
result = output
1221+
except asyncio.CancelledError:
1222+
# Best-effort: if the executor or tool exposes a cancel/terminate, call it
1223+
_maybe_call_cancel_hook(call.local_shell_tool)
1224+
# If your executor returns a proc handle (common pattern), adddress it here if needed:
1225+
# with contextlib.suppress(Exception):
1226+
# proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1227+
# proc.kill()
1228+
raise
11761229

11771230
await asyncio.gather(
11781231
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
@@ -1185,7 +1238,7 @@ async def execute(
11851238

11861239
return ToolCallOutputItem(
11871240
agent=agent,
1188-
output=output,
1241+
output=result,
11891242
raw_item={
11901243
"type": "local_shell_call_output",
11911244
"id": call.tool_call.call_id,

src/agents/models/openai_responses.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
from collections.abc import AsyncIterator
56
from dataclasses import dataclass
@@ -172,15 +173,30 @@ async def stream_response(
172173

173174
final_response: Response | None = None
174175

175-
async for chunk in stream:
176-
if isinstance(chunk, ResponseCompletedEvent):
177-
final_response = chunk.response
178-
yield chunk
176+
try:
177+
async for chunk in stream: # ensure type checkers relax here
178+
if isinstance(chunk, ResponseCompletedEvent):
179+
final_response = chunk.response
180+
yield chunk
181+
except asyncio.CancelledError:
182+
# Cooperative cancel: ensure the HTTP stream is closed, then propagate
183+
try:
184+
await stream.close()
185+
except Exception:
186+
pass
187+
raise
188+
finally:
189+
# Always close the stream if the async iterator exits (normal or error)
190+
try:
191+
await stream.close()
192+
except Exception:
193+
pass
179194

180195
if final_response and tracing.include_data():
181196
span_response.span_data.response = final_response
182197
span_response.span_data.input = input
183198

199+
184200
except Exception as e:
185201
span_response.set_error(
186202
SpanError(

src/agents/result.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import asyncio
5+
import contextlib
56
from collections.abc import AsyncIterator
67
from dataclasses import dataclass, field
78
from typing import TYPE_CHECKING, Any, cast
@@ -143,6 +144,12 @@ class RunResultStreaming(RunResultBase):
143144
is_complete: bool = False
144145
"""Whether the agent has finished running."""
145146

147+
_emit_status_events: bool = False
148+
"""Whether to emit RunUpdatedStreamEvent status updates.
149+
150+
Defaults to False for backward compatibility.
151+
"""
152+
146153
# Queues that the background run_loop writes to
147154
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
148155
default_factory=asyncio.Queue, repr=False
@@ -164,17 +171,45 @@ def last_agent(self) -> Agent[Any]:
164171
"""
165172
return self.current_agent
166173

167-
def cancel(self) -> None:
168-
"""Cancels the streaming run, stopping all background tasks and marking the run as
169-
complete."""
170-
self._cleanup_tasks() # Cancel all running tasks
171-
self.is_complete = True # Mark the run as complete to stop event streaming
174+
def cancel(self, reason: str | None = None) -> None:
175+
# 1) Signal cooperative cancel to the runner
176+
active = getattr(self, "_active_run", None)
177+
if active:
178+
with contextlib.suppress(Exception):
179+
active.cancel(reason)
180+
# 2) Do NOT cancel the background task; let the loop unwind cooperatively
181+
# task = getattr(self, "_run_impl_task", None)
182+
# if task and not task.done():
183+
# with contextlib.suppress(Exception):
184+
# task.cancel()
185+
186+
# 4) Mark complete; flushing only when status events are disabled
187+
self.is_complete = True
188+
if not getattr(self, "_emit_status_events", False):
189+
with contextlib.suppress(Exception):
190+
while not self._event_queue.empty():
191+
self._event_queue.get_nowait()
192+
self._event_queue.task_done()
193+
with contextlib.suppress(Exception):
194+
while not self._input_guardrail_queue.empty():
195+
self._input_guardrail_queue.get_nowait()
196+
self._input_guardrail_queue.task_done()
197+
198+
def inject(self, items: list[TResponseInputItem]) -> None:
199+
"""
200+
Inject new input items mid-run. They will be consumed at the start of the next step.
201+
"""
202+
active = getattr(self, "_active_run", None)
203+
if active is not None:
204+
try:
205+
active.inject(items)
206+
except Exception:
207+
pass
172208

173-
# Optionally, clear the event queue to prevent processing stale events
174-
while not self._event_queue.empty():
175-
self._event_queue.get_nowait()
176-
while not self._input_guardrail_queue.empty():
177-
self._input_guardrail_queue.get_nowait()
209+
@property
210+
def active_run(self):
211+
"""Access the underlying ActiveRun handle (may be None early in startup)."""
212+
return getattr(self, "_active_run", None)
178213

179214
async def stream_events(self) -> AsyncIterator[StreamEvent]:
180215
"""Stream deltas for new items as they are generated. We're using the types from the
@@ -243,21 +278,33 @@ def _check_errors(self):
243278
# Check the tasks for any exceptions
244279
if self._run_impl_task and self._run_impl_task.done():
245280
run_impl_exc = self._run_impl_task.exception()
246-
if run_impl_exc and isinstance(run_impl_exc, Exception):
281+
if (
282+
run_impl_exc
283+
and isinstance(run_impl_exc, Exception)
284+
and not isinstance(run_impl_exc, asyncio.CancelledError)
285+
):
247286
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
248287
run_impl_exc.run_data = self._create_error_details()
249288
self._stored_exception = run_impl_exc
250289

251290
if self._input_guardrails_task and self._input_guardrails_task.done():
252291
in_guard_exc = self._input_guardrails_task.exception()
253-
if in_guard_exc and isinstance(in_guard_exc, Exception):
292+
if (
293+
in_guard_exc
294+
and isinstance(in_guard_exc, Exception)
295+
and not isinstance(in_guard_exc, asyncio.CancelledError)
296+
):
254297
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
255298
in_guard_exc.run_data = self._create_error_details()
256299
self._stored_exception = in_guard_exc
257300

258301
if self._output_guardrails_task and self._output_guardrails_task.done():
259302
out_guard_exc = self._output_guardrails_task.exception()
260-
if out_guard_exc and isinstance(out_guard_exc, Exception):
303+
if (
304+
out_guard_exc
305+
and isinstance(out_guard_exc, Exception)
306+
and not isinstance(out_guard_exc, asyncio.CancelledError)
307+
):
261308
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
262309
out_guard_exc.run_data = self._create_error_details()
263310
self._stored_exception = out_guard_exc

0 commit comments

Comments
 (0)