diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 888e262..305b403 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_nodes( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[str]] = None +) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -41,6 +43,12 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the nodes. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] # Start and end the graph @@ -76,12 +84,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: f"shape=box, style=filled, style=rounded, " f"fillcolor=lightyellow, width=1.5, height=0.8];" ) - parts.append(get_all_nodes(handoff)) + parts.append(get_all_nodes(handoff, agent, visited)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_edges( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[str]] = None +) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -92,6 +102,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the edges. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] if not parent: @@ -109,7 +125,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: if isinstance(handoff, Agent): parts.append(f""" "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + parts.append(get_all_edges(handoff, agent, visited)) if not agent.handoffs and not isinstance(agent, Tool): # type: ignore parts.append(f'"{agent.name}" -> "__end__";') diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6aa8677..9cc3ee9 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -134,3 +134,25 @@ def test_draw_graph(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source ) + + +@pytest.fixture +def circular_agents() -> Agent: + agent1 = Agent(name="Agent1") + agent2 = Agent(name="Agent2") + agent1.handoffs = [agent2] + agent2.handoffs = [agent1] + return agent1 + + +def test_get_all_nodes_handles_cycle(circular_agents: Agent) -> None: + result = get_all_nodes(circular_agents) + assert '"Agent1" [label="Agent1"' in result + assert '"Agent2" [label="Agent2"' in result + + +def test_get_all_edges_handles_cycle(circular_agents: Agent) -> None: + result = get_all_edges(circular_agents) + assert result.count('"Agent1" -> "Agent2";') == 1 + assert result.count('"Agent2" -> "Agent1";') == 1 + assert '"Agent1" -> "__end__";' not in result