From 649d188088ddd337eacc955d893931aca51678f3 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 11 Jul 2025 12:10:08 -0400 Subject: [PATCH 1/2] [1/n] Break Agent into AgentBase+Agent --- src/agents/__init__.py | 3 +- src/agents/agent.py | 84 ++++++++++++++++++++++++++----------- src/agents/lifecycle.py | 43 +++++++++++-------- src/agents/mcp/server.py | 10 ++--- src/agents/mcp/util.py | 11 +++-- src/agents/tool.py | 10 ++--- tests/test_function_tool.py | 11 ++++- 7 files changed, 111 insertions(+), 61 deletions(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 1bcb06e9f..7de17efdb 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -5,7 +5,7 @@ from openai import AsyncOpenAI from . import _config -from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult +from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( @@ -161,6 +161,7 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", + "AgentBase", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", diff --git a/src/agents/agent.py b/src/agents/agent.py index 2f77f5f2a..e418b67da 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -67,7 +67,63 @@ class MCPConfig(TypedDict): @dataclass -class Agent(Generic[TContext]): +class AgentBase(Generic[TContext]): + """Base class for `Agent` and `RealtimeAgent`.""" + + name: str + """The name of the agent.""" + + handoff_description: str | None = None + """A description of the agent. This is used when the agent is used as a handoff, so that an + LLM knows what it does and when to invoke it. 
+ """ + + tools: list[Tool] = field(default_factory=list) + """A list of tools that the agent can use.""" + + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + + mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) + """Configuration for MCP servers.""" + + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) + + async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + mcp_tools = await self.get_mcp_tools(run_context) + + async def _check_tool_enabled(tool: Tool) -> bool: + if not isinstance(tool, FunctionTool): + return True + + attr = tool.is_enabled + if isinstance(attr, bool): + return attr + res = attr(run_context, self) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) + enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] + return [*mcp_tools, *enabled] + + +@dataclass +class Agent(AgentBase, Generic[TContext]): """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more. We strongly recommend passing `instructions`, which is the "system prompt" for the agent. 
In @@ -76,10 +132,9 @@ class Agent(Generic[TContext]): Agents are generic on the context type. The context is a (mutable) object you create. It is passed to tool functions, handoffs, guardrails, etc. - """ - name: str - """The name of the agent.""" + See `AgentBase` for base parameters that are shared with `RealtimeAgent`s. + """ instructions: ( str @@ -103,11 +158,6 @@ class Agent(Generic[TContext]): usable with OpenAI models, using the Responses API. """ - handoff_description: str | None = None - """A description of the agent. This is used when the agent is used as a handoff, so that an - LLM knows what it does and when to invoke it. - """ - handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list) """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs, and the agent can choose to delegate to them if relevant. Allows for separation of concerns and @@ -125,22 +175,6 @@ class Agent(Generic[TContext]): """Configures model-specific tuning parameters (e.g. temperature, top_p). """ - tools: list[Tool] = field(default_factory=list) - """A list of tools that the agent can use.""" - - mcp_servers: list[MCPServer] = field(default_factory=list) - """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that - the agent can use. Every time the agent runs, it will include tools from these servers in the - list of available tools. - - NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call - `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no - longer needed. - """ - - mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) - """Configuration for MCP servers.""" - input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. 
diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 8643248b1..2cce496c8 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -1,25 +1,27 @@ from typing import Any, Generic -from .agent import Agent +from typing_extensions import TypeVar + +from .agent import Agent, AgentBase from .run_context import RunContextWrapper, TContext from .tool import Tool +TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase) + -class RunHooks(Generic[TContext]): +class RunHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events in an agent run. Subclass and override the methods you need. """ - async def on_agent_start( - self, context: RunContextWrapper[TContext], agent: Agent[TContext] - ) -> None: + async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: """Called before the agent is invoked. Called each time the current agent changes.""" pass async def on_agent_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, output: Any, ) -> None: """Called when the agent produces a final output.""" @@ -28,8 +30,8 @@ async def on_agent_end( async def on_handoff( self, context: RunContextWrapper[TContext], - from_agent: Agent[TContext], - to_agent: Agent[TContext], + from_agent: TAgent, + to_agent: TAgent, ) -> None: """Called when a handoff occurs.""" pass @@ -37,7 +39,7 @@ async def on_handoff( async def on_tool_start( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, ) -> None: """Called before a tool is invoked.""" @@ -46,7 +48,7 @@ async def on_tool_start( async def on_tool_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, result: str, ) -> None: @@ -54,14 +56,14 @@ async def on_tool_end( pass -class AgentHooks(Generic[TContext]): +class AgentHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various 
lifecycle events for a specific agent. You can set this on `agent.hooks` to receive events for that specific agent. Subclass and override the methods you need. """ - async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None: + async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: """Called before the agent is invoked. Called each time the running agent is changed to this agent.""" pass @@ -69,7 +71,7 @@ async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TCon async def on_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, output: Any, ) -> None: """Called when the agent produces a final output.""" @@ -78,8 +80,8 @@ async def on_end( async def on_handoff( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], - source: Agent[TContext], + agent: TAgent, + source: TAgent, ) -> None: """Called when the agent is being handed off to. The `source` is the agent that is handing off to this agent.""" @@ -88,7 +90,7 @@ async def on_handoff( async def on_tool_start( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, ) -> None: """Called before a tool is invoked.""" @@ -97,9 +99,16 @@ async def on_tool_start( async def on_tool_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, result: str, ) -> None: """Called after a tool is invoked.""" pass + + +RunHooks = RunHooksBase[TContext, Agent] +"""Run hooks when using `Agent`.""" + +AgentHooks = AgentHooksBase[TContext, Agent] +"""Agent hooks for `Agent`s.""" diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 4fd606e34..91a9274fc 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -22,7 +22,7 @@ from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic if TYPE_CHECKING: - from ..agent import Agent + from ..agent import 
AgentBase class MCPServer(abc.ABC): @@ -53,7 +53,7 @@ async def cleanup(self): async def list_tools( self, run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + agent: AgentBase | None = None, ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -117,7 +117,7 @@ async def _apply_tool_filter( self, tools: list[MCPTool], run_context: RunContextWrapper[Any], - agent: Agent[Any], + agent: AgentBase, ) -> list[MCPTool]: """Apply the tool filter to the list of tools.""" if self.tool_filter is None: @@ -153,7 +153,7 @@ async def _apply_dynamic_tool_filter( self, tools: list[MCPTool], run_context: RunContextWrapper[Any], - agent: Agent[Any], + agent: AgentBase, ) -> list[MCPTool]: """Apply dynamic tool filtering using a callable filter function.""" @@ -244,7 +244,7 @@ async def connect(self): async def list_tools( self, run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + agent: AgentBase | None = None, ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 48da9f841..18cf4440a 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -5,12 +5,11 @@ from typing_extensions import NotRequired, TypedDict -from agents.strict_schema import ensure_strict_json_schema - from .. 
import _debug from ..exceptions import AgentsException, ModelBehaviorError, UserError from ..logger import logger from ..run_context import RunContextWrapper +from ..strict_schema import ensure_strict_json_schema from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span from ..util._types import MaybeAwaitable @@ -18,7 +17,7 @@ if TYPE_CHECKING: from mcp.types import Tool as MCPTool - from ..agent import Agent + from ..agent import AgentBase from .server import MCPServer @@ -29,7 +28,7 @@ class ToolFilterContext: run_context: RunContextWrapper[Any] """The current run context.""" - agent: "Agent[Any]" + agent: "AgentBase" """The agent that is requesting the tool list.""" server_name: str @@ -100,7 +99,7 @@ async def get_all_function_tools( servers: list["MCPServer"], convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], - agent: "Agent[Any]", + agent: "AgentBase", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] @@ -126,7 +125,7 @@ async def get_function_tools( server: "MCPServer", convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], - agent: "Agent[Any]", + agent: "AgentBase", ) -> list[Tool]: """Get all function tools from a single MCP server.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index f2861b936..0faf09846 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -30,7 +30,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - from .agent import Agent + from .agent import Agent, AgentBase ToolParams = ParamSpec("ToolParams") @@ -87,7 +87,7 @@ class FunctionTool: """Whether the JSON schema is in strict mode. 
We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True """Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" @@ -300,7 +300,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -315,7 +315,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -330,7 +330,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. 
By default, we will: diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index b232bf75e..9b02a2c42 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -5,7 +5,14 @@ from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool +from agents import ( + Agent, + AgentBase, + FunctionTool, + ModelBehaviorError, + RunContextWrapper, + function_tool, +) from agents.tool import default_tool_error_function from agents.tool_context import ToolContext @@ -268,7 +275,7 @@ async def test_is_enabled_bool_and_callable(): def disabled_tool(): return "nope" - async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool: + async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: AgentBase) -> bool: return ctx.context.enable_tools @function_tool(is_enabled=cond_enabled) From 231052d71537688c55690ba9f863569b5427b637 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 11 Jul 2025 12:10:09 -0400 Subject: [PATCH 2/2] [2/n] Introduce RealtimeAgent --- src/agents/realtime/agent.py | 80 ++++++++++++++++++++++++++++++++++++ tests/realtime/test_agent.py | 25 +++++++++++ 2 files changed, 105 insertions(+) create mode 100644 src/agents/realtime/agent.py create mode 100644 tests/realtime/test_agent.py diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py new file mode 100644 index 000000000..9bbed8cb4 --- /dev/null +++ b/src/agents/realtime/agent.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import Any, Callable, Generic, cast + +from ..agent import AgentBase +from ..lifecycle import AgentHooksBase, RunHooksBase +from ..logger import logger +from ..run_context import RunContextWrapper, TContext +from ..util._types import MaybeAwaitable + +RealtimeAgentHooks = 
AgentHooksBase[TContext, "RealtimeAgent[TContext]"] +"""Agent hooks for `RealtimeAgent`s.""" + +RealtimeRunHooks = RunHooksBase[TContext, "RealtimeAgent[TContext]"] +"""Run hooks for `RealtimeAgent`s.""" + + +@dataclass +class RealtimeAgent(AgentBase, Generic[TContext]): + """A specialized agent instance that is meant to be used within a `RealtimeSession` to build + voice agents. Due to the nature of this agent, some configuration options are not supported + that are supported by regular `Agent` instances. For example: + - `model` choice is not supported, as all RealtimeAgents will be handled by the same model + within a `RealtimeSession`. + - `model_settings` is not supported, as all RealtimeAgents will be handled by the same model + within a `RealtimeSession`. + - `output_type` is not supported, as RealtimeAgents do not support structured outputs. + - `tool_use_behavior` is not supported, as all RealtimeAgents will be handled by the same model + within a `RealtimeSession`. + - `voice` can be configured on an `Agent` level; however, it cannot be changed after the first + agent within a `RealtimeSession` has spoken. + + See `AgentBase` for base parameters that are shared with `Agent`s. + """ + + instructions: ( + str + | Callable[ + [RunContextWrapper[TContext], RealtimeAgent[TContext]], + MaybeAwaitable[str], + ] + | None + ) = None + """The instructions for the agent. Will be used as the "system prompt" when this agent is + invoked. Describes what the agent should do, and how it responds. + + Can either be a string, or a function that dynamically generates instructions for the agent. If + you provide a function, it will be called with the context and the agent instance. It must + return a string. + """ + + hooks: RealtimeAgentHooks | None = None + """A class that receives callbacks on various lifecycle events for this agent. + """ + + def clone(self, **kwargs: Any) -> RealtimeAgent[TContext]: + """Make a copy of the agent, with the given arguments changed. 
For example, you could do: + ``` + new_agent = agent.clone(instructions="New instructions") + ``` + """ + return dataclasses.replace(self, **kwargs) + + async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: + """Get the system prompt for the agent.""" + if isinstance(self.instructions, str): + return self.instructions + elif callable(self.instructions): + if inspect.iscoroutinefunction(self.instructions): + return await cast(Awaitable[str], self.instructions(run_context, self)) + else: + return cast(str, self.instructions(run_context, self)) + elif self.instructions is not None: + logger.error(f"Instructions must be a string or a function, got {self.instructions}") + + return None diff --git a/tests/realtime/test_agent.py b/tests/realtime/test_agent.py new file mode 100644 index 000000000..aae8bc47c --- /dev/null +++ b/tests/realtime/test_agent.py @@ -0,0 +1,25 @@ +import pytest + +from agents import RunContextWrapper +from agents.realtime.agent import RealtimeAgent + + +def test_can_initialize_realtime_agent(): + agent = RealtimeAgent(name="test", instructions="Hello") + assert agent.name == "test" + assert agent.instructions == "Hello" + + +@pytest.mark.asyncio +async def test_dynamic_instructions(): + agent = RealtimeAgent(name="test") + assert agent.instructions is None + + def _instructions(ctx, agt) -> str: + assert ctx.context is None + assert agt == agent + return "Dynamic" + + agent = RealtimeAgent(name="test", instructions=_instructions) + instructions = await agent.get_system_prompt(RunContextWrapper(context=None)) + assert instructions == "Dynamic"