1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ import contextlib
4
5
import dataclasses
5
6
import inspect
6
7
from collections .abc import Awaitable
@@ -226,6 +227,29 @@ def get_model_tracing_impl(
226
227
return ModelTracing .ENABLED_WITHOUT_DATA
227
228
228
229
230
+ # Helpers for cancellable tool execution
231
+
232
+
233
+ async def _await_cancellable (awaitable ):
234
+ """Await an awaitable in its own task so CancelledError interrupts promptly."""
235
+ task = asyncio .create_task (awaitable )
236
+ try :
237
+ return await task
238
+ except asyncio .CancelledError :
239
+ # propagate so run.py can handle terminal cancel
240
+ raise
241
+
242
+
243
+ def _maybe_call_cancel_hook (tool_obj ) -> None :
244
+ """Best-effort: call a cancel/terminate hook on the tool if present."""
245
+ for name in ("cancel" , "terminate" , "stop" ):
246
+ cb = getattr (tool_obj , name , None )
247
+ if callable (cb ):
248
+ with contextlib .suppress (Exception ):
249
+ cb ()
250
+ break
251
+
252
+
229
253
class RunImpl :
230
254
@classmethod
231
255
async def execute_tools_and_side_effects (
@@ -556,16 +580,24 @@ async def run_single_tool(
556
580
if config .trace_include_sensitive_data :
557
581
span_fn .span_data .input = tool_call .arguments
558
582
try :
559
- _ , _ , result = await asyncio .gather (
583
+ # run start hooks first (don’t tie them to the cancellable task)
584
+ await asyncio .gather (
560
585
hooks .on_tool_start (tool_context , agent , func_tool ),
561
586
(
562
587
agent .hooks .on_tool_start (tool_context , agent , func_tool )
563
588
if agent .hooks
564
589
else _coro .noop_coroutine ()
565
590
),
566
- func_tool .on_invoke_tool (tool_context , tool_call .arguments ),
567
591
)
568
592
593
+ try :
594
+ result = await _await_cancellable (
595
+ func_tool .on_invoke_tool (tool_context , tool_call .arguments )
596
+ )
597
+ except asyncio .CancelledError :
598
+ _maybe_call_cancel_hook (func_tool )
599
+ raise
600
+
569
601
await asyncio .gather (
570
602
hooks .on_tool_end (tool_context , agent , func_tool , result ),
571
603
(
@@ -574,6 +606,7 @@ async def run_single_tool(
574
606
else _coro .noop_coroutine ()
575
607
),
576
608
)
609
+
577
610
except Exception as e :
578
611
_error_tracing .attach_error_to_current_span (
579
612
SpanError (
@@ -644,7 +677,6 @@ async def execute_computer_actions(
644
677
config : RunConfig ,
645
678
) -> list [RunItem ]:
646
679
results : list [RunItem ] = []
647
- # Need to run these serially, because each action can affect the computer state
648
680
for action in actions :
649
681
acknowledged : list [ComputerCallOutputAcknowledgedSafetyCheck ] | None = None
650
682
if action .tool_call .pending_safety_checks and action .computer_tool .on_safety_check :
@@ -661,24 +693,28 @@ async def execute_computer_actions(
661
693
if ack :
662
694
acknowledged .append (
663
695
ComputerCallOutputAcknowledgedSafetyCheck (
664
- id = check .id ,
665
- code = check .code ,
666
- message = check .message ,
696
+ id = check .id , code = check .code , message = check .message
667
697
)
668
698
)
669
699
else :
670
700
raise UserError ("Computer tool safety check was not acknowledged" )
671
701
672
- results .append (
673
- await ComputerAction .execute (
674
- agent = agent ,
675
- action = action ,
676
- hooks = hooks ,
677
- context_wrapper = context_wrapper ,
678
- config = config ,
679
- acknowledged_safety_checks = acknowledged ,
702
+ try :
703
+ item = await _await_cancellable (
704
+ ComputerAction .execute (
705
+ agent = agent ,
706
+ action = action ,
707
+ hooks = hooks ,
708
+ context_wrapper = context_wrapper ,
709
+ config = config ,
710
+ acknowledged_safety_checks = acknowledged ,
711
+ )
680
712
)
681
- )
713
+ except asyncio .CancelledError :
714
+ _maybe_call_cancel_hook (action .computer_tool )
715
+ raise
716
+
717
+ results .append (item )
682
718
683
719
return results
684
720
@@ -1052,16 +1088,23 @@ async def execute(
1052
1088
else cls ._get_screenshot_sync (action .computer_tool .computer , action .tool_call )
1053
1089
)
1054
1090
1055
- _ , _ , output = await asyncio .gather (
1091
+ # start hooks first
1092
+ await asyncio .gather (
1056
1093
hooks .on_tool_start (context_wrapper , agent , action .computer_tool ),
1057
1094
(
1058
1095
agent .hooks .on_tool_start (context_wrapper , agent , action .computer_tool )
1059
1096
if agent .hooks
1060
1097
else _coro .noop_coroutine ()
1061
1098
),
1062
- output_func ,
1063
1099
)
1064
-
1100
+ # run the action (screenshot/etc) in a cancellable task
1101
+ try :
1102
+ output = await _await_cancellable (output_func )
1103
+ except asyncio .CancelledError :
1104
+ _maybe_call_cancel_hook (action .computer_tool )
1105
+ raise
1106
+
1107
+ # end hooks
1065
1108
await asyncio .gather (
1066
1109
hooks .on_tool_end (context_wrapper , agent , action .computer_tool , output ),
1067
1110
(
@@ -1169,10 +1212,20 @@ async def execute(
1169
1212
data = call .tool_call ,
1170
1213
)
1171
1214
output = call .local_shell_tool .executor (request )
1172
- if inspect .isawaitable (output ):
1173
- result = await output
1174
- else :
1175
- result = output
1215
+ try :
1216
+ if inspect .isawaitable (output ):
1217
+ result = await _await_cancellable (output )
1218
+ else :
1219
+ # If executor returns a sync result, just use it (can’t cancel mid-call)
1220
+ result = output
1221
+ except asyncio .CancelledError :
1222
+ # Best-effort: if the executor or tool exposes a cancel/terminate, call it
1223
+ _maybe_call_cancel_hook (call .local_shell_tool )
1224
+ # If your executor returns a proc handle (common pattern), adddress it here if needed:
1225
+ # with contextlib.suppress(Exception):
1226
+ # proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1227
+ # proc.kill()
1228
+ raise
1176
1229
1177
1230
await asyncio .gather (
1178
1231
hooks .on_tool_end (context_wrapper , agent , call .local_shell_tool , result ),
@@ -1185,7 +1238,7 @@ async def execute(
1185
1238
1186
1239
return ToolCallOutputItem (
1187
1240
agent = agent ,
1188
- output = output ,
1241
+ output = result ,
1189
1242
raw_item = {
1190
1243
"type" : "local_shell_call_output" ,
1191
1244
"id" : call .tool_call .call_id ,
0 commit comments