Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Update sort step method for assistant invoke. #10191

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions python/semantic_kernel/agents/open_ai/open_ai_assistant_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,8 @@ async def _invoke_internal(
# Filter out None values to avoid passing them as kwargs
run_options = {k: v for k, v in run_options.items() if v is not None}

logger.info(f"Starting invoke for agent `{self.name}` and thread `{thread_id}`")

run = await self.client.beta.threads.runs.create(
assistant_id=self.assistant.id,
thread_id=thread_id,
Expand All @@ -755,8 +757,13 @@ async def _invoke_internal(

# Check if function calling required
if run.status == "requires_action":
logger.info(f"Run [{run.id}] requires action for agent `{self.name}` and thread `{thread_id}`")
moonbox3 marked this conversation as resolved.
Show resolved Hide resolved
fccs = get_function_call_contents(run, function_steps)
if fccs:
logger.info(
f"Yielding `generate_function_call_content` for agent `{self.name}` and "
f"thread `{thread_id}`, visibility False"
)
yield False, generate_function_call_content(agent_name=self.name, fccs=fccs)

from semantic_kernel.contents.chat_history import ChatHistory
Expand All @@ -770,28 +777,52 @@ async def _invoke_internal(
thread_id=thread_id,
tool_outputs=tool_outputs, # type: ignore
)
logger.info(f"Submitted tool outputs for agent `{self.name}` and thread `{thread_id}`")

steps_response = await self.client.beta.threads.runs.steps.list(run_id=run.id, thread_id=thread_id)
logger.info(f"Called for steps_response for run [{run.id}] agent `{self.name}` and thread `{thread_id}`")
steps: list[RunStep] = steps_response.data
completed_steps_to_process: list[RunStep] = sorted(
[s for s in steps if s.completed_at is not None and s.id not in processed_step_ids],
key=lambda s: s.created_at,

def sort_key(step: RunStep):
    """Sort key ordering run steps: all tool_calls steps first, then
    message_creation steps; within the same type, earlier completion wins."""
    type_rank = 0 if step.type == "tool_calls" else 1
    return (type_rank, step.completed_at)

completed_steps_to_process = sorted(
[s for s in steps if s.completed_at is not None and s.id not in processed_step_ids], key=sort_key
)

logger.info(
f"Completed steps to process for run [{run.id}] agent `{self.name}` and thread `{thread_id}` "
f"with length `{len(completed_steps_to_process)}`"
)

message_count = 0
for completed_step in completed_steps_to_process:
if completed_step.type == "tool_calls":
logger.info(
f"Entering step type tool_calls for run [{run.id}], agent `{self.name}` and "
f"thread `{thread_id}`"
)
assert hasattr(completed_step.step_details, "tool_calls") # nosec
for tool_call in completed_step.step_details.tool_calls:
is_visible = False
content: "ChatMessageContent | None" = None
if tool_call.type == "code_interpreter":
logger.info(
f"Entering step type tool_calls for run [{run.id}], [code_interpreter] for "
f"agent `{self.name}` and thread `{thread_id}`"
)
content = generate_code_interpreter_content(
self.name,
tool_call.code_interpreter.input, # type: ignore
)
is_visible = True
elif tool_call.type == "function":
logger.info(
f"Entering step type tool_calls for run [{run.id}], [function] for agent `{self.name}` "
f"and thread `{thread_id}`"
)
function_step = function_steps.get(tool_call.id)
assert function_step is not None # nosec
content = generate_function_result_content(
Expand All @@ -800,8 +831,16 @@ async def _invoke_internal(

if content:
message_count += 1
logger.info(
f"Yielding tool_message for run [{run.id}], agent `{self.name}` and thread "
f"`{thread_id}` and message count `{message_count}`, is_visible `{is_visible}`"
)
yield is_visible, content
elif completed_step.type == "message_creation":
logger.info(
f"Entering step type message_creation for run [{run.id}], agent `{self.name}` and "
f"thread `{thread_id}`"
)
message = await self._retrieve_message(
thread_id=thread_id,
message_id=completed_step.step_details.message_creation.message_id, # type: ignore
Expand All @@ -810,6 +849,10 @@ async def _invoke_internal(
content = generate_message_content(self.name, message)
if content and len(content.items) > 0:
message_count += 1
logger.info(
f"Yielding message_creation for run [{run.id}], agent `{self.name}` and "
f"thread `{thread_id}` and message count `{message_count}`, is_visible `{True}`"
)
yield True, content
processed_step_ids.add(completed_step.id)

Expand Down
87 changes: 87 additions & 0 deletions python/tests/unit/agents/test_open_ai_assistant_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,35 @@ def __init__(self):
)


@pytest.fixture
def mock_run_step_function_tool_call():
    """Build a completed tool_calls RunStep containing a single function tool call.

    The step's created_at is one minute before completed_at so ordering logic
    that inspects timestamps sees a realistic window.
    """
    return RunStep(
        id="step_id_1",
        type="tool_calls",
        completed_at=int(datetime.now(timezone.utc).timestamp()),
        created_at=int((datetime.now(timezone.utc) - timedelta(minutes=1)).timestamp()),
        step_details=ToolCallsStepDetails(
            tool_calls=[
                FunctionToolCall(
                    type="function",
                    id="tool_call_id",
                    # Fixed: keyword was misspelled `outpt`, so the function's
                    # `output` field was never populated on the mock.
                    function=RunsFunction(arguments="{}", name="function_name", output="test output"),
                ),
            ],
            type="tool_calls",
        ),
        assistant_id="assistant_id",
        object="thread.run.step",
        run_id="run_id",
        status="completed",
        thread_id="thread_id",
    )


@pytest.fixture
def mock_run_step_message_creation():
class MockMessageCreation:
Expand Down Expand Up @@ -1206,6 +1235,64 @@ def mock_get_function_call_contents(run, function_steps):
_ = [message async for message in azure_openai_assistant_agent.invoke("thread_id")]


async def test_invoke_order(
    azure_openai_assistant_agent,
    mock_assistant,
    mock_run_required_action,
    mock_run_step_function_tool_call,
    mock_run_step_message_creation,
    mock_thread_messages,
    mock_function_call_content,
):
    """Verify _invoke_internal yields content in the expected order.

    The steps list is deliberately returned with the message_creation step
    FIRST and the tool_calls step SECOND; the assertions at the bottom check
    that the agent re-sorts them so the function call/result are yielded
    before the text message.
    """
    poll_count = 0

    async def mock_poll_run_status(run, thread_id):
        nonlocal poll_count
        # First poll leaves the run in `requires_action` so the tool-call
        # branch executes once; any later poll flips it to `completed` so the
        # invoke loop terminates.
        if run.status == "requires_action":
            if poll_count == 0:
                pass
            else:
                run.status = "completed"
        poll_count += 1
        return run

    def mock_get_function_call_contents(run, function_steps):
        # Register the mock function call under the same id the tool-call
        # fixture uses ("tool_call_id") so result generation can find it.
        function_call_content = mock_function_call_content
        function_call_content.id = "tool_call_id"
        function_steps[function_call_content.id] = function_call_content
        return [function_call_content]

    azure_openai_assistant_agent.assistant = mock_assistant
    azure_openai_assistant_agent._poll_run_status = AsyncMock(side_effect=mock_poll_run_status)
    azure_openai_assistant_agent._retrieve_message = AsyncMock(return_value=mock_thread_messages[0])

    with patch(
        "semantic_kernel.agents.open_ai.assistant_content_generation.get_function_call_contents",
        side_effect=mock_get_function_call_contents,
    ):
        client = azure_openai_assistant_agent.client

        with patch.object(client.beta.threads.runs, "create", new_callable=AsyncMock) as mock_runs_create:
            mock_runs_create.return_value = mock_run_required_action

            with (
                patch.object(client.beta.threads.runs, "submit_tool_outputs", new_callable=AsyncMock),
                patch.object(client.beta.threads.runs.steps, "list", new_callable=AsyncMock) as mock_steps_list,
            ):
                # Intentionally out of order: message step before the tool step.
                mock_steps_list.return_value = MagicMock(
                    data=[mock_run_step_message_creation, mock_run_step_function_tool_call]
                )

                messages = []
                async for _, content in azure_openai_assistant_agent._invoke_internal("thread_id"):
                    messages.append(content)

                # Expected order after sorting: the yielded function call, its
                # result, then the assistant's text message.
                assert len(messages) == 3
                assert isinstance(messages[0].items[0], FunctionCallContent)
                assert isinstance(messages[1].items[0], FunctionResultContent)
                assert isinstance(messages[2].items[0], TextContent)


async def test_invoke_stream(
azure_openai_assistant_agent,
mock_assistant,
Expand Down
Loading
Loading