Skip to content

Commit

Permalink
FunctionTool partial support (#5183)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

FunctionTool supports passing in a partial

## Related issue number

Closes #5151 

## Checks

- [x] I've included any doc changes needed for
https://microsoft.github.io/autogen/. See
https://microsoft.github.io/autogen/docs/Contribute#documentation to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.
  • Loading branch information
nour-bouzid authored Jan 29, 2025
1 parent 2f1684b commit 02e968a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import inspect
import typing
from functools import partial
from logging import getLogger
from typing import (
Annotated,
Expand Down Expand Up @@ -41,7 +42,8 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
func_call = call.func if isinstance(call, partial) else call
type_hints = typing.get_type_hints(func_call, globalns, include_extras=True)
typed_params = [
inspect.Parameter(
name=param.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self._func = func
self._global_imports = global_imports
signature = get_typed_signature(func)
func_name = name or func.__name__
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters
Expand Down
62 changes: 62 additions & 0 deletions python/packages/autogen-core/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from functools import partial
from typing import Annotated, List

import pytest
Expand Down Expand Up @@ -109,6 +110,67 @@ def my_function(arg: str = "default") -> MyResult:
assert "required" not in schema["parameters"]


def test_func_tool_with_partial_positional_arguments_schema_generation() -> None:
"""Test correct schema generation for a partial function with positional arguments."""

def get_weather(country: str, city: str) -> str:
return f"The temperature in {city}, {country} is 75°"

partial_function = partial(get_weather, "Germany")
tool = FunctionTool(partial_function, description="Partial function tool.")
schema = tool.schema

assert schema["name"] == "get_weather"
assert "description" in schema
assert schema["description"] == "Partial function tool."
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert schema["parameters"]["properties"].keys() == {"city"}
assert schema["parameters"]["properties"]["city"]["type"] == "string"
assert schema["parameters"]["properties"]["city"]["description"] == "city"
assert "required" in schema["parameters"]
assert schema["parameters"]["required"] == ["city"]
assert "country" not in schema["parameters"]["properties"] # check country not in schema params
assert len(schema["parameters"]["properties"]) == 1


def test_func_call_tool_with_kwargs_schema_generation() -> None:
"""Test correct schema generation for a partial function with kwargs."""

def get_weather(country: str, city: str) -> str:
return f"The temperature in {city}, {country} is 75°"

partial_function = partial(get_weather, country="Germany")
tool = FunctionTool(partial_function, description="Partial function tool.")
schema = tool.schema

assert schema["name"] == "get_weather"
assert "description" in schema
assert schema["description"] == "Partial function tool."
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert schema["parameters"]["properties"].keys() == {"country", "city"}
assert schema["parameters"]["properties"]["city"]["type"] == "string"
assert schema["parameters"]["properties"]["country"]["type"] == "string"
assert "required" in schema["parameters"]
assert schema["parameters"]["required"] == ["city"] # only city is required
assert len(schema["parameters"]["properties"]) == 2


@pytest.mark.asyncio
async def test_run_func_call_tool_with_kwargs_and_args() -> None:
"""Test run partial function with kwargs and args."""

def get_weather(country: str, city: str, unit: str = "Celsius") -> str:
return f"The temperature in {city}, {country} is 75° {unit}"

partial_function = partial(get_weather, "Germany", unit="Fahrenheit")
tool = FunctionTool(partial_function, description="Partial function tool.")
result = await tool.run_json({"city": "Berlin"}, CancellationToken())
assert isinstance(result, str)
assert result == "The temperature in Berlin, Germany is 75° Fahrenheit"


@pytest.mark.asyncio
async def test_tool_run() -> None:
tool = MyTool()
Expand Down

0 comments on commit 02e968a

Please sign in to comment.