Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3267,10 +3267,15 @@ async def a_generate_reply(
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages_before_reply(messages)

# Get sync functions to skip (those with async equivalents)
sync_to_skip = self._get_sync_funcs_to_skip_in_async_chat()

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if reply_func in exclude:
continue
if reply_func in sync_to_skip:
continue

if self._match_trigger(reply_func_tuple["trigger"], sender):
if is_coroutine_callable(reply_func):
Expand All @@ -3286,6 +3291,30 @@ async def a_generate_reply(
return reply
return self._default_auto_reply

def _get_sync_funcs_to_skip_in_async_chat(self) -> set[Callable[..., Any]]:
"""Get sync reply functions that should be skipped in async chat.

When an async reply function is registered with ignore_async_in_sync_chat=True,
it indicates that a sync equivalent exists. In async chat, we should skip the
sync version and use the async version instead.

Returns:
A set of sync reply functions that have async equivalents.
"""
sync_to_skip: set[Callable[..., Any]] = set()
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if is_coroutine_callable(reply_func) and reply_func_tuple.get("ignore_async_in_sync_chat"):
func_name = reply_func.__name__
if func_name.startswith("a_"):
sync_name = func_name[2:] # Remove "a_" prefix
for other_tuple in self._reply_func_list:
other_func = other_tuple["reply_func"]
if not is_coroutine_callable(other_func) and other_func.__name__ == sync_name:
sync_to_skip.add(other_func)
break
return sync_to_skip

def _match_trigger(self, trigger: None | str | type | Agent | Callable | list, sender: Agent | None) -> bool:
"""Check if the sender matches the trigger.

Expand Down
132 changes: 132 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,3 +2245,135 @@ def runtime_tool(message: str) -> str:
assert len(executor.function_map) == 2
assert "pre_tool" in executor.function_map
assert "runtime_tool" in executor.function_map


class TestAsyncReplyFunctionSkipping:
"""Tests for skipping sync reply functions in async chat when async equivalents exist."""

def test_get_sync_funcs_to_skip_in_async_chat(self):
"""Test that _get_sync_funcs_to_skip_in_async_chat identifies correct sync functions."""
agent = ConversableAgent(name="test", llm_config=False, human_input_mode="NEVER")

sync_to_skip = agent._get_sync_funcs_to_skip_in_async_chat()

# Should identify all sync functions that have async equivalents
sync_names = {f.__name__ for f in sync_to_skip}
expected = {
"check_termination_and_human_reply",
"generate_oai_reply",
"generate_tool_calls_reply",
"generate_function_call_reply",
}
assert sync_names == expected

def test_get_sync_funcs_to_skip_excludes_sync_only_functions(self):
"""Test that sync-only functions (no async equivalent) are not in skip set."""
agent = ConversableAgent(
name="test",
llm_config=False,
human_input_mode="NEVER",
code_execution_config={"executor": "commandline-local"},
)

sync_to_skip = agent._get_sync_funcs_to_skip_in_async_chat()
sync_names = {f.__name__ for f in sync_to_skip}

# Code execution reply is sync-only, should NOT be in skip set
assert "_generate_code_execution_reply_using_executor" not in sync_names

def test_custom_sync_only_function_not_skipped(self):
"""Test that user-registered sync-only functions are not skipped in async chat."""
agent = ConversableAgent(name="test", llm_config=False, human_input_mode="NEVER")

def custom_sync_reply(recipient, messages, sender, config):
return (True, "custom sync response")

agent.register_reply([autogen.Agent, None], custom_sync_reply)

sync_to_skip = agent._get_sync_funcs_to_skip_in_async_chat()

# Custom sync function should NOT be in skip set
assert custom_sync_reply not in sync_to_skip

def test_custom_async_with_sync_equivalent_skips_sync(self):
"""Test that when user registers async with ignore_async_in_sync_chat, sync is skipped."""
agent = ConversableAgent(name="test", llm_config=False, human_input_mode="NEVER")

def my_reply(recipient, messages, sender, config):
return (True, "sync response")

async def a_my_reply(recipient, messages, sender, config):
return (True, "async response")

# Register both with the naming convention and flag
agent.register_reply([autogen.Agent, None], my_reply)
agent.register_reply([autogen.Agent, None], a_my_reply, ignore_async_in_sync_chat=True)

sync_to_skip = agent._get_sync_funcs_to_skip_in_async_chat()
sync_names = {f.__name__ for f in sync_to_skip}

# my_reply should be in skip set because a_my_reply has ignore_async_in_sync_chat=True
assert "my_reply" in sync_names

@pytest.mark.asyncio
async def test_a_generate_reply_skips_sync_with_async_equivalent(self):
"""Test that a_generate_reply skips sync functions when async equivalents exist."""
agent = ConversableAgent(name="test", llm_config=False, human_input_mode="NEVER")
sender = ConversableAgent(name="sender", llm_config=False)
agent._oai_messages[sender] = [{"role": "user", "content": "hello"}]

# Reset counter to verify it's only incremented once
agent._consecutive_auto_reply_counter[sender] = 0

await agent.a_generate_reply(sender=sender)

# Counter should be incremented exactly once (by async version only)
# If sync was also called, it would be 2
assert agent._consecutive_auto_reply_counter[sender] == 1

@pytest.mark.asyncio
async def test_a_generate_reply_calls_sync_only_functions(self):
"""Test that a_generate_reply still calls sync-only functions."""
agent = ConversableAgent(name="test", llm_config=False, human_input_mode="NEVER")
sender = ConversableAgent(name="sender", llm_config=False)

sync_only_called = []

def sync_only_reply(recipient, messages, sender, config):
sync_only_called.append(True)
return (False, None) # Don't finalize, let other functions run

agent.register_reply([autogen.Agent, None], sync_only_reply)
agent._oai_messages[sender] = [{"role": "user", "content": "hello"}]

await agent.a_generate_reply(sender=sender)

# sync_only_reply should have been called
assert len(sync_only_called) == 1

@pytest.mark.asyncio
async def test_a_generate_reply_prefers_async_over_sync(self):
"""Test that when both sync and async exist, only async is called."""
agent = ConversableAgent(name="test", llm_config=False, human_input_mode="NEVER")
sender = ConversableAgent(name="sender", llm_config=False)

call_log = []

def my_reply(recipient, messages, sender, config):
call_log.append("sync")
return (True, "sync response")

async def a_my_reply(recipient, messages, sender, config):
call_log.append("async")
return (True, "async response")

# Register both with naming convention and flag
agent.register_reply([autogen.Agent, None], my_reply)
agent.register_reply([autogen.Agent, None], a_my_reply, ignore_async_in_sync_chat=True)
agent._oai_messages[sender] = [{"role": "user", "content": "hello"}]

result = await agent.a_generate_reply(sender=sender)

# Only async should be called
assert call_log == ["async"]
assert result == "async response"
Loading