Skip to content

feat(viz): draw MCP servers #1368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ pip install "openai-agents[viz]"
You can generate an agent visualization using the `draw_graph` function. This function creates a directed graph where:

- **Agents** are represented as yellow boxes.
- **MCP Servers** are represented as grey boxes.
- **Tools** are represented as green ellipses.
- **Handoffs** are directed edges from one agent to another.

### Example Usage

```python
import os

from agents import Agent, function_tool
from agents.mcp.server import MCPServerStdio
from agents.extensions.visualization import draw_graph

@function_tool
Expand All @@ -38,11 +42,22 @@ english_agent = Agent(
instructions="You only speak English",
)

current_dir = os.path.dirname(os.path.abspath(__file__))
samples_dir = os.path.join(current_dir, "sample_files")
mcp_server = MCPServerStdio(
name="Filesystem Server, via npx",
params={
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir],
},
)

triage_agent = Agent(
name="Triage agent",
instructions="Handoff to the appropriate agent based on the language of the request.",
handoffs=[spanish_agent, english_agent],
tools=[get_weather],
mcp_servers=[mcp_server],
)

draw_graph(triage_agent)
Expand All @@ -60,9 +75,11 @@ The generated graph includes:
- A **start node** (`__start__`) indicating the entry point.
- Agents represented as **rectangles** with yellow fill.
- Tools represented as **ellipses** with green fill.
- MCP Servers represented as **rectangles** with grey fill.
- Directed edges indicating interactions:
- **Solid arrows** for agent-to-agent handoffs.
- **Dotted arrows** for tool invocations.
- **Dashed arrows** for MCP server invocations.
- An **end node** (`__end__`) indicating where execution terminates.

## Customizing the Graph
Expand Down
11 changes: 11 additions & 0 deletions src/agents/extensions/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def get_all_nodes(
f"fillcolor=lightgreen, width=0.5, height=0.3];"
)

for mcp_server in agent.mcp_servers:
parts.append(
f'"{mcp_server.name}" [label="{mcp_server.name}", shape=box, style=filled, '
f"fillcolor=lightgrey, width=1, height=0.5];"
)

for handoff in agent.handoffs:
if isinstance(handoff, Handoff):
parts.append(
Expand Down Expand Up @@ -119,6 +125,11 @@ def get_all_edges(
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")

for mcp_server in agent.mcp_servers:
parts.append(f"""
"{agent.name}" -> "{mcp_server.name}" [style=dashed, penwidth=1.5];
"{mcp_server.name}" -> "{agent.name}" [style=dashed, penwidth=1.5];""")

for handoff in agent.handoffs:
if isinstance(handoff, Handoff):
parts.append(f"""
Expand Down
30 changes: 30 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from unittest.mock import Mock

import graphviz # type: ignore
Expand All @@ -12,6 +13,9 @@
)
from agents.handoffs import Handoff

if sys.version_info >= (3, 10):
from .mcp.helpers import FakeMCPServer


@pytest.fixture
def mock_agent():
Expand All @@ -27,6 +31,10 @@ def mock_agent():
agent.name = "Agent1"
agent.tools = [tool1, tool2]
agent.handoffs = [handoff1]
agent.mcp_servers = []

if sys.version_info >= (3, 10):
agent.mcp_servers = [FakeMCPServer(server_name="MCPServer1")]

return agent

Expand Down Expand Up @@ -62,6 +70,7 @@ def test_get_main_graph(mock_agent):
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
_assert_mcp_nodes(result)


def test_get_all_nodes(mock_agent):
Expand Down Expand Up @@ -90,6 +99,7 @@ def test_get_all_nodes(mock_agent):
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
_assert_mcp_nodes(result)


def test_get_all_edges(mock_agent):
Expand All @@ -101,6 +111,7 @@ def test_get_all_edges(mock_agent):
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
assert '"Agent1" -> "Handoff1";' in result
_assert_mcp_edges(result)


def test_draw_graph(mock_agent):
Expand Down Expand Up @@ -134,6 +145,25 @@ def test_draw_graph(mock_agent):
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
)
_assert_mcp_nodes(graph.source)


def _assert_mcp_nodes(source: str):
if sys.version_info < (3, 10):
assert "MCPServer1" not in source
return
assert (
'"MCPServer1" [label="MCPServer1", shape=box, style=filled, '
"fillcolor=lightgrey, width=1, height=0.5];" in source
)


def _assert_mcp_edges(source: str):
if sys.version_info < (3, 10):
assert "MCPServer1" not in source
return
assert '"Agent1" -> "MCPServer1" [style=dashed, penwidth=1.5];' in source
assert '"MCPServer1" -> "Agent1" [style=dashed, penwidth=1.5];' in source


def test_cycle_detection():
Expand Down