diff --git a/docs/visualization.md b/docs/visualization.md index 409803f76..1fcea7743 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -15,13 +15,17 @@ pip install "openai-agents[viz]" You can generate an agent visualization using the `draw_graph` function. This function creates a directed graph where: - **Agents** are represented as yellow boxes. +- **MCP Servers** are represented as grey boxes. - **Tools** are represented as green ellipses. - **Handoffs** are directed edges from one agent to another. ### Example Usage ```python +import os + from agents import Agent, function_tool +from agents.mcp.server import MCPServerStdio from agents.extensions.visualization import draw_graph @function_tool @@ -38,11 +42,22 @@ english_agent = Agent( instructions="You only speak English", ) +current_dir = os.path.dirname(os.path.abspath(__file__)) +samples_dir = os.path.join(current_dir, "sample_files") +mcp_server = MCPServerStdio( + name="Filesystem Server, via npx", + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, +) + triage_agent = Agent( name="Triage agent", instructions="Handoff to the appropriate agent based on the language of the request.", handoffs=[spanish_agent, english_agent], tools=[get_weather], + mcp_servers=[mcp_server], ) draw_graph(triage_agent) @@ -60,9 +75,11 @@ The generated graph includes: - A **start node** (`__start__`) indicating the entry point. - Agents represented as **rectangles** with yellow fill. - Tools represented as **ellipses** with green fill. +- MCP Servers represented as **rectangles** with grey fill. - Directed edges indicating interactions: - **Solid arrows** for agent-to-agent handoffs. - **Dotted arrows** for tool invocations. + - **Dashed arrows** for MCP server invocations. - An **end node** (`__end__`) indicating where execution terminates. ## Customizing the Graph diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index be762a330..67ca7d267 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -71,6 +71,12 @@ def get_all_nodes( f"fillcolor=lightgreen, width=0.5, height=0.3];" ) + for mcp_server in agent.mcp_servers: + parts.append( + f'"{mcp_server.name}" [label="{mcp_server.name}", shape=box, style=filled, ' + f"fillcolor=lightgrey, width=1, height=0.5];" + ) + for handoff in agent.handoffs: if isinstance(handoff, Handoff): parts.append( @@ -119,6 +125,11 @@ def get_all_edges( "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""") + for mcp_server in agent.mcp_servers: + parts.append(f""" + "{agent.name}" -> "{mcp_server.name}" [style=dashed, penwidth=1.5]; + "{mcp_server.name}" -> "{agent.name}" [style=dashed, penwidth=1.5];""") + for handoff in agent.handoffs: if isinstance(handoff, Handoff): parts.append(f""" diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 8bce897e9..89211cc9c 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,3 +1,4 @@ +import sys from unittest.mock import Mock import graphviz # type: ignore @@ -12,6 +13,9 @@ ) from agents.handoffs import Handoff +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + @pytest.fixture def mock_agent(): @@ -27,6 +31,10 @@ def mock_agent(): agent.name = "Agent1" agent.tools = [tool1, tool2] agent.handoffs = [handoff1] + agent.mcp_servers = [] + + if sys.version_info >= (3, 10): + agent.mcp_servers = [FakeMCPServer(server_name="MCPServer1")] return agent @@ -62,6 +70,7 @@ def test_get_main_graph(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in result ) + _assert_mcp_nodes(result) def test_get_all_nodes(mock_agent): @@ -90,6 +99,7 @@ def test_get_all_nodes(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in result ) + _assert_mcp_nodes(result) def test_get_all_edges(mock_agent): @@ -101,6 +111,7 @@ def test_get_all_edges(mock_agent): assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Handoff1";' in result + _assert_mcp_edges(result) def test_draw_graph(mock_agent): @@ -134,6 +145,25 @@ def test_draw_graph(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source ) + _assert_mcp_nodes(graph.source) + + +def _assert_mcp_nodes(source: str): + if sys.version_info < (3, 10): + assert "MCPServer1" not in source + return + assert ( + '"MCPServer1" [label="MCPServer1", shape=box, style=filled, ' + "fillcolor=lightgrey, width=1, height=0.5];" in source + ) + + +def _assert_mcp_edges(source: str): + if sys.version_info < (3, 10): + assert "MCPServer1" not in source + return + assert '"Agent1" -> "MCPServer1" [style=dashed, penwidth=1.5];' in source + assert '"MCPServer1" -> "Agent1" [style=dashed, penwidth=1.5];' in source def test_cycle_detection():