Skip to content

feat: add retry logic to MCP server operations #1554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import abc
import asyncio
import inspect
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
Expand All @@ -21,6 +22,8 @@
from ..run_context import RunContextWrapper
from .util import ToolFilter, ToolFilterContext, ToolFilterStatic

T = TypeVar("T")

if TYPE_CHECKING:
from ..agent import AgentBase

Expand Down Expand Up @@ -98,6 +101,8 @@ def __init__(
client_session_timeout_seconds: float | None,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
):
"""
Args:
Expand All @@ -115,6 +120,10 @@ def __init__(
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, used for exponential
backoff between retries.
"""
super().__init__(use_structured_content=use_structured_content)
self.session: ClientSession | None = None
Expand All @@ -124,6 +133,8 @@ def __init__(
self.server_initialize_result: InitializeResult | None = None

self.client_session_timeout_seconds = client_session_timeout_seconds
self.max_retry_attempts = max_retry_attempts
self.retry_backoff_seconds_base = retry_backoff_seconds_base

# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
Expand Down Expand Up @@ -233,6 +244,18 @@ def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True

async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
    """Await ``func()`` and retry it with exponential backoff when it raises.

    Args:
        func: Zero-argument callable that returns a fresh awaitable on every call.
            It is invoked once per attempt.

    Returns:
        The result of the first attempt that completes without raising.

    Raises:
        Exception: Whatever the last attempt raised, re-raised unchanged once more
            than ``self.max_retry_attempts`` attempts have failed. When
            ``self.max_retry_attempts`` is ``-1`` the limit is disabled and the call
            is retried until it succeeds.
    """
    # Cap the exponent: an unbounded ``2 ** n`` eventually becomes an int too large
    # to multiply with a float base, which would raise OverflowError from inside
    # the handler and hide the real failure. Delays are unchanged for the first
    # 63 failures.
    max_backoff_exponent = 62
    failures = 0
    while True:
        try:
            return await func()
        except Exception:
            failures += 1
            if self.max_retry_attempts != -1 and failures > self.max_retry_attempts:
                # Retry budget exhausted: surface the original exception.
                raise
            exponent = min(failures - 1, max_backoff_exponent)
            # Delay doubles after each failure: base, 2*base, 4*base, ...
            await asyncio.sleep(self.retry_backoff_seconds_base * (2**exponent))

async def connect(self):
"""Connect to the server."""
try:
Expand Down Expand Up @@ -267,15 +290,17 @@ async def list_tools(
"""List the tools available on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")
session = self.session
assert session is not None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nit: it feels redundant to `assert session is not None` right below the early `if not self.session` check, which already raises.


# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
tools = self._tools_list
else:
# Reset the cache dirty to False
self._cache_dirty = False
# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
result = await self._run_with_retries(lambda: session.list_tools())
self._tools_list = result.tools
self._cache_dirty = False
tools = self._tools_list

# Filter tools based on tool_filter
Expand All @@ -290,8 +315,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
"""Invoke a tool on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")
session = self.session
assert session is not None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same super nit: the `assert session is not None` is redundant right after the `if not self.session` check.


return await self.session.call_tool(tool_name, arguments)
return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))

async def list_prompts(
self,
Expand Down Expand Up @@ -365,6 +392,8 @@ def __init__(
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
):
"""Create a new MCP server based on the stdio transport.
Expand All @@ -388,12 +417,18 @@ def __init__(
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, for exponential
backoff between retries.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
use_structured_content,
max_retry_attempts,
retry_backoff_seconds_base,
)

self.params = StdioServerParameters(
Expand Down Expand Up @@ -455,6 +490,8 @@ def __init__(
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
):
"""Create a new MCP server based on the HTTP with SSE transport.
Expand All @@ -480,12 +517,18 @@ def __init__(
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, for exponential
backoff between retries.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
use_structured_content,
max_retry_attempts,
retry_backoff_seconds_base,
)

self.params = params
Expand Down Expand Up @@ -547,6 +590,8 @@ def __init__(
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
):
"""Create a new MCP server based on the Streamable HTTP transport.
Expand All @@ -573,12 +618,18 @@ def __init__(
include the structured content in the `tool_result.content`, and using it by
default will cause duplicate content. You can set this to True if you know the
server will not duplicate the structured content in the `tool_result.content`.
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, for exponential
backoff between retries.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
use_structured_content,
max_retry_attempts,
retry_backoff_seconds_base,
)

self.params = params
Expand Down
64 changes: 64 additions & 0 deletions tests/mcp/test_client_session_retries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import cast

import pytest
from mcp import ClientSession, Tool as MCPTool
from mcp.types import CallToolResult, ListToolsResult

from agents.mcp.server import _MCPServerWithClientSession


class DummySession:
    """Fake session whose calls fail a configurable number of times.

    Every call bumps the matching ``*_attempts`` counter. While that counter is
    no greater than the configured failure count, the call raises
    ``RuntimeError``; afterwards it returns a canned result.
    """

    def __init__(self, fail_call_tool: int = 0, fail_list_tools: int = 0):
        # Counters record how many times each method has been invoked.
        self.call_tool_attempts = 0
        self.list_tools_attempts = 0
        # Number of leading invocations of each method that must fail.
        self.fail_call_tool = fail_call_tool
        self.fail_list_tools = fail_list_tools

    async def call_tool(self, tool_name, arguments):
        """Count the attempt, then fail or return an empty ``CallToolResult``."""
        self.call_tool_attempts += 1
        still_failing = self.fail_call_tool >= self.call_tool_attempts
        if still_failing:
            raise RuntimeError("call_tool failure")
        return CallToolResult(content=[])

    async def list_tools(self):
        """Count the attempt, then fail or return a single tool named ``tool``."""
        self.list_tools_attempts += 1
        still_failing = self.fail_list_tools >= self.list_tools_attempts
        if still_failing:
            raise RuntimeError("list_tools failure")
        only_tool = MCPTool(name="tool", inputSchema={})
        return ListToolsResult(tools=[only_tool])


class DummyServer(_MCPServerWithClientSession):
    """Concrete server for retry tests that is wired directly to a ``DummySession``."""

    def __init__(self, session: DummySession, retries: int):
        # A backoff base of 0 is passed so the computed delay between retries is 0.
        super().__init__(
            max_retry_attempts=retries,
            retry_backoff_seconds_base=0,
            cache_tools_list=False,
            client_session_timeout_seconds=None,
        )
        # Install the fake session directly instead of obtaining one via connect().
        fake_session = cast(ClientSession, session)
        self.session = fake_session

    def create_streams(self):
        """Not implemented for this test double; always raises."""
        raise NotImplementedError()

    @property
    def name(self) -> str:
        """Fixed server name used by the tests."""
        return "dummy"


@pytest.mark.asyncio
async def test_call_tool_retries_until_success():
    """call_tool succeeds on the third attempt when two retries are allowed."""
    flaky_session = DummySession(fail_call_tool=2)
    retrying_server = DummyServer(session=flaky_session, retries=2)

    outcome = await retrying_server.call_tool("tool", None)

    # Two failures plus the final successful call.
    assert isinstance(outcome, CallToolResult)
    assert flaky_session.call_tool_attempts == 3


@pytest.mark.asyncio
async def test_list_tools_unlimited_retries():
    """list_tools keeps retrying through three failures when retries is -1."""
    flaky_session = DummySession(fail_list_tools=3)
    retrying_server = DummyServer(session=flaky_session, retries=-1)

    listed = await retrying_server.list_tools()

    assert len(listed) == 1
    first_tool = listed[0]
    assert first_tool.name == "tool"
    # Three failures plus the final successful call.
    assert flaky_session.list_tools_attempts == 4