diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index d75b0c4e0..0acb1345a 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -3,10 +3,11 @@ import abc import asyncio import inspect +from collections.abc import Awaitable from contextlib import AbstractAsyncContextManager, AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client @@ -21,6 +22,8 @@ from ..run_context import RunContextWrapper from .util import ToolFilter, ToolFilterContext, ToolFilterStatic +T = TypeVar("T") + if TYPE_CHECKING: from ..agent import AgentBase @@ -98,6 +101,8 @@ def __init__( client_session_timeout_seconds: float | None, tool_filter: ToolFilter = None, use_structured_content: bool = False, + max_retry_attempts: int = 0, + retry_backoff_seconds_base: float = 1.0, ): """ Args: @@ -115,6 +120,10 @@ def __init__( include the structured content in the `tool_result.content`, and using it by default will cause duplicate content. You can set this to True if you know the server will not duplicate the structured content in the `tool_result.content`. + max_retry_attempts: Number of times to retry failed list_tools/call_tool calls. + Defaults to no retries. + retry_backoff_seconds_base: The base delay, in seconds, used for exponential + backoff between retries. """ super().__init__(use_structured_content=use_structured_content) self.session: ClientSession | None = None @@ -124,6 +133,8 @@ def __init__( self.server_initialize_result: InitializeResult | None = None self.client_session_timeout_seconds = client_session_timeout_seconds + self.max_retry_attempts = max_retry_attempts + self.retry_backoff_seconds_base = retry_backoff_seconds_base # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True @@ -233,6 +244,18 @@ def invalidate_tools_cache(self): """Invalidate the tools cache.""" self._cache_dirty = True + async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T: + attempts = 0 + while True: + try: + return await func() + except Exception: + attempts += 1 + if self.max_retry_attempts != -1 and attempts > self.max_retry_attempts: + raise + backoff = self.retry_backoff_seconds_base * (2 ** (attempts - 1)) + await asyncio.sleep(backoff) + async def connect(self): """Connect to the server.""" try: @@ -267,15 +290,17 @@ async def list_tools( """List the tools available on the server.""" if not self.session: raise UserError("Server not initialized. Make sure you call `connect()` first.") + session = self.session + assert session is not None # Return from cache if caching is enabled, we have tools, and the cache is not dirty if self.cache_tools_list and not self._cache_dirty and self._tools_list: tools = self._tools_list else: - # Reset the cache dirty to False - self._cache_dirty = False # Fetch the tools from the server - self._tools_list = (await self.session.list_tools()).tools + result = await self._run_with_retries(lambda: session.list_tools()) + self._tools_list = result.tools + self._cache_dirty = False tools = self._tools_list # Filter tools based on tool_filter @@ -290,8 +315,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C """Invoke a tool on the server.""" if not self.session: raise UserError("Server not initialized. Make sure you call `connect()` first.") + session = self.session + assert session is not None - return await self.session.call_tool(tool_name, arguments) + return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments)) async def list_prompts( self, @@ -365,6 +392,8 @@ def __init__( client_session_timeout_seconds: float | None = 5, tool_filter: ToolFilter = None, use_structured_content: bool = False, + max_retry_attempts: int = 0, + retry_backoff_seconds_base: float = 1.0, ): """Create a new MCP server based on the stdio transport. @@ -388,12 +417,18 @@ def __init__( include the structured content in the `tool_result.content`, and using it by default will cause duplicate content. You can set this to True if you know the server will not duplicate the structured content in the `tool_result.content`. + max_retry_attempts: Number of times to retry failed list_tools/call_tool calls. + Defaults to no retries. + retry_backoff_seconds_base: The base delay, in seconds, for exponential + backoff between retries. """ super().__init__( cache_tools_list, client_session_timeout_seconds, tool_filter, use_structured_content, + max_retry_attempts, + retry_backoff_seconds_base, ) self.params = StdioServerParameters( @@ -455,6 +490,8 @@ def __init__( client_session_timeout_seconds: float | None = 5, tool_filter: ToolFilter = None, use_structured_content: bool = False, + max_retry_attempts: int = 0, + retry_backoff_seconds_base: float = 1.0, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -480,12 +517,18 @@ def __init__( include the structured content in the `tool_result.content`, and using it by default will cause duplicate content. You can set this to True if you know the server will not duplicate the structured content in the `tool_result.content`. + max_retry_attempts: Number of times to retry failed list_tools/call_tool calls. + Defaults to no retries. + retry_backoff_seconds_base: The base delay, in seconds, for exponential + backoff between retries. """ super().__init__( cache_tools_list, client_session_timeout_seconds, tool_filter, use_structured_content, + max_retry_attempts, + retry_backoff_seconds_base, ) self.params = params @@ -547,6 +590,8 @@ def __init__( client_session_timeout_seconds: float | None = 5, tool_filter: ToolFilter = None, use_structured_content: bool = False, + max_retry_attempts: int = 0, + retry_backoff_seconds_base: float = 1.0, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -573,12 +618,18 @@ def __init__( include the structured content in the `tool_result.content`, and using it by default will cause duplicate content. You can set this to True if you know the server will not duplicate the structured content in the `tool_result.content`. + max_retry_attempts: Number of times to retry failed list_tools/call_tool calls. + Defaults to no retries. + retry_backoff_seconds_base: The base delay, in seconds, for exponential + backoff between retries. """ super().__init__( cache_tools_list, client_session_timeout_seconds, tool_filter, use_structured_content, + max_retry_attempts, + retry_backoff_seconds_base, ) self.params = params diff --git a/tests/mcp/test_client_session_retries.py b/tests/mcp/test_client_session_retries.py new file mode 100644 index 000000000..4cc292a3a --- /dev/null +++ b/tests/mcp/test_client_session_retries.py @@ -0,0 +1,64 @@ +from typing import cast + +import pytest +from mcp import ClientSession, Tool as MCPTool +from mcp.types import CallToolResult, ListToolsResult + +from agents.mcp.server import _MCPServerWithClientSession + + +class DummySession: + def __init__(self, fail_call_tool: int = 0, fail_list_tools: int = 0): + self.fail_call_tool = fail_call_tool + self.fail_list_tools = fail_list_tools + self.call_tool_attempts = 0 + self.list_tools_attempts = 0 + + async def call_tool(self, tool_name, arguments): + self.call_tool_attempts += 1 + if self.call_tool_attempts <= self.fail_call_tool: + raise RuntimeError("call_tool failure") + return CallToolResult(content=[]) + + async def list_tools(self): + self.list_tools_attempts += 1 + if self.list_tools_attempts <= self.fail_list_tools: + raise RuntimeError("list_tools failure") + return ListToolsResult(tools=[MCPTool(name="tool", inputSchema={})]) + + +class DummyServer(_MCPServerWithClientSession): + def __init__(self, session: DummySession, retries: int): + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + max_retry_attempts=retries, + retry_backoff_seconds_base=0, + ) + self.session = cast(ClientSession, session) + + def create_streams(self): + raise NotImplementedError + + @property + def name(self) -> str: + return "dummy" + + +@pytest.mark.asyncio +async def test_call_tool_retries_until_success(): + session = DummySession(fail_call_tool=2) + server = DummyServer(session=session, retries=2) + result = await server.call_tool("tool", None) + assert isinstance(result, CallToolResult) + assert session.call_tool_attempts == 3 + + +@pytest.mark.asyncio +async def test_list_tools_unlimited_retries(): + session = DummySession(fail_list_tools=3) + server = DummyServer(session=session, retries=-1) + tools = await server.list_tools() + assert len(tools) == 1 + assert tools[0].name == "tool" + assert session.list_tools_attempts == 4