diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 597c15dfa..3824a3c18 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -1380,6 +1380,7 @@ def tool( exclude_args: list[str] | None = None, meta: dict[str, Any] | None = None, enabled: bool | None = None, + unpack_pydantic_args: bool = False, ) -> FunctionTool: ... @overload @@ -1397,6 +1398,7 @@ def tool( exclude_args: list[str] | None = None, meta: dict[str, Any] | None = None, enabled: bool | None = None, + unpack_pydantic_args: bool = False, ) -> Callable[[AnyFunction], FunctionTool]: ... def tool( @@ -1413,6 +1415,7 @@ def tool( exclude_args: list[str] | None = None, meta: dict[str, Any] | None = None, enabled: bool | None = None, + unpack_pydantic_args: bool = False, ) -> Callable[[AnyFunction], FunctionTool] | FunctionTool: """Decorator to register a tool. @@ -1501,6 +1504,7 @@ def my_tool(x: int) -> str: meta=meta, serializer=self._tool_serializer, enabled=enabled, + unpack_pydantic_args=unpack_pydantic_args, ) self.add_tool(tool) return tool @@ -1534,6 +1538,7 @@ def my_tool(x: int) -> str: exclude_args=exclude_args, meta=meta, enabled=enabled, + unpack_pydantic_args=unpack_pydantic_args, ) def add_resource(self, resource: Resource) -> Resource: diff --git a/src/fastmcp/tools/tool.py b/src/fastmcp/tools/tool.py index 552c4d729..febf58690 100644 --- a/src/fastmcp/tools/tool.py +++ b/src/fastmcp/tools/tool.py @@ -188,6 +188,7 @@ def from_function( serializer: ToolResultSerializerType | None = None, meta: dict[str, Any] | None = None, enabled: bool | None = None, + unpack_pydantic_args: bool = False, ) -> FunctionTool: """Create a Tool from a function.""" return FunctionTool.from_function( @@ -203,6 +204,7 @@ def from_function( serializer=serializer, meta=meta, enabled=enabled, + unpack_pydantic_args=unpack_pydantic_args, ) async def run(self, arguments: dict[str, Any]) -> ToolResult: @@ -254,6 +256,8 @@ def from_tool( class FunctionTool(Tool): fn: Callable[..., Any] + unpack_pydantic_args: bool = False + unpacked_models_map: dict | None = None @classmethod def from_function( @@ -270,6 +274,7 @@ def from_function( serializer: ToolResultSerializerType | None = None, meta: dict[str, Any] | None = None, enabled: bool | None = None, + unpack_pydantic_args: bool = False, ) -> FunctionTool: """Create a Tool from a function.""" if exclude_args and fastmcp.settings.deprecation_warnings: @@ -283,7 +288,11 @@ def from_function( stacklevel=2, ) - parsed_fn = ParsedFunction.from_function(fn, exclude_args=exclude_args) + parsed_fn = ParsedFunction.from_function( + fn, + exclude_args=exclude_args, + unpack_pydantic_args=unpack_pydantic_args, + ) if name is None and parsed_fn.name == "": raise ValueError("You must provide a name for lambda functions") @@ -325,6 +334,8 @@ def from_function( serializer=serializer, meta=meta, enabled=enabled if enabled is not None else True, + unpack_pydantic_args=unpack_pydantic_args, + unpacked_models_map=parsed_fn.unpacked_models_map, ) async def run(self, arguments: dict[str, Any]) -> ToolResult: @@ -333,6 +344,28 @@ async def run(self, arguments: dict[str, Any]) -> ToolResult: arguments = arguments.copy() + # If unpacking is enabled, re-assemble the Pydantic models from the arguments + if self.unpack_pydantic_args and self.unpacked_models_map: + assembled_args = {} + consumed_keys = set() + + for arg_name, model_cls in self.unpacked_models_map.items(): + model_fields = model_cls.model_fields.keys() + model_kwargs = {} + for field_name in model_fields: + if field_name in arguments: + model_kwargs[field_name] = arguments[field_name] + consumed_keys.add(field_name) + + assembled_args[arg_name] = model_cls(**model_kwargs) + + # Add the remaining non-pydantic arguments + for key, value in arguments.items(): + if key not in consumed_keys: + assembled_args[key] = value + + arguments = assembled_args + context_kwarg = find_kwarg_by_type(self.fn, kwarg_type=Context) if context_kwarg and context_kwarg not in arguments: arguments[context_kwarg] = get_context() @@ -400,6 +433,7 @@ class ParsedFunction: description: str | None input_schema: dict[str, Any] output_schema: dict[str, Any] | None + unpacked_models_map: dict[str, Any] | None = None @classmethod def from_function( @@ -408,6 +442,7 @@ def from_function( exclude_args: list[str] | None = None, validate: bool = True, wrap_non_object_output_schema: bool = True, + unpack_pydantic_args: bool = False, ) -> ParsedFunction: from fastmcp.server.context import Context @@ -453,12 +488,97 @@ def from_function( if exclude_args: prune_params.extend(exclude_args) + unpacked_models_map = {} + fn_for_schema = fn + + if unpack_pydantic_args: + import pydantic + + original_sig = inspect.signature(fn) + new_params = [] + has_pydantic_model = False + + seen_field_names: set[str] = set() + + # First pass: collect non-Pydantic param names + for param in original_sig.parameters.values(): + if not ( + isinstance(param.annotation, type) + and issubclass(param.annotation, pydantic.BaseModel) + ): + if param.name in seen_field_names: + raise ValueError(f"Duplicate parameter name: {param.name}") + seen_field_names.add(param.name) + + for param in original_sig.parameters.values(): + # Check if the parameter type is a Pydantic model + if isinstance(param.annotation, type) and issubclass( + param.annotation, pydantic.BaseModel + ): + has_pydantic_model = True + unpacked_models_map[param.name] = param.annotation + # Unpack the model's fields into new parameters + for field_name, field in param.annotation.model_fields.items(): + if field_name in seen_field_names: + raise ValueError( + f"Field name '{field_name}' from Pydantic model '{param.annotation.__name__}' " + f"conflicts with another parameter. Cannot unpack." + ) + seen_field_names.add(field_name) + + # Create a new parameter for each field + # Handle default vs default_factory + if field.is_required(): + default = inspect.Parameter.empty + elif field.default is not pydantic_core.PydanticUndefined: + default = field.default + else: + # Field has default_factory - cannot represent as static default + # Use empty to make it required in schema (factory runs at validation) + default = inspect.Parameter.empty + + new_param = inspect.Parameter( + name=field_name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=field.annotation, + ) + new_params.append(new_param) + else: + new_params.append(param) + + if has_pydantic_model: + # Sort parameters to ensure non-defaults come first + required_params = [ + p for p in new_params if p.default == inspect.Parameter.empty + ] + optional_params = [ + p for p in new_params if p.default != inspect.Parameter.empty + ] + new_params = required_params + optional_params + + # Create a new function with the unpacked signature for schema generation + def placeholder_fn(*args, **kwargs): ... + + new_sig = original_sig.replace(parameters=new_params) + new_annotations = { + p.name: p.annotation + for p in new_params + if p.annotation != inspect.Parameter.empty + } + + placeholder_fn.__signature__ = new_sig # type: ignore[attr-defined] + placeholder_fn.__annotations__ = new_annotations # type: ignore[attr-defined] + fn_for_schema = placeholder_fn + # Create a function without excluded parameters in annotations # This prevents Pydantic from trying to serialize non-serializable types # before we can exclude them in compress_schema - fn_for_typeadapter = fn + fn_for_typeadapter = fn_for_schema if prune_params: - fn_for_typeadapter = create_function_without_params(fn, prune_params) + fn_for_typeadapter = create_function_without_params( + fn_for_schema, prune_params + ) input_type_adapter = get_cached_typeadapter(fn_for_typeadapter) input_schema = input_type_adapter.json_schema() @@ -535,6 +655,7 @@ def from_function( description=fn_doc, input_schema=input_schema, output_schema=output_schema or None, + unpacked_models_map=unpacked_models_map or None, ) diff --git a/tests/test_unpack_pydantic_args.py b/tests/test_unpack_pydantic_args.py new file mode 100644 index 000000000..d7bb1c361 --- /dev/null +++ b/tests/test_unpack_pydantic_args.py @@ -0,0 +1,131 @@ +import pytest +from pydantic import BaseModel, Field + +from fastmcp import FastMCP + + +class User(BaseModel): + name: str = Field(description="The user's name") + age: int = Field(description="The user's age") + + +def test_unpack_pydantic_args(): + mcp = FastMCP("test") + + @mcp.tool(unpack_pydantic_args=True) + def greet_user(user: User, greeting: str = "Hello") -> str: + return f"{greeting}, {user.name}! You are {user.age} years old." + + tool = mcp._tool_manager._tools["greet_user"] + + # Check schema + schema = tool.parameters + assert "name" in schema["properties"] + assert "age" in schema["properties"] + assert "greeting" in schema["properties"] + assert "user" not in schema["properties"] + + # Check required fields + assert "name" in schema["required"] + assert "age" in schema["required"] + assert "greeting" not in schema.get("required", []) + + # Run tool + import asyncio + + result = asyncio.run(tool.run({"name": "Alice", "age": 30, "greeting": "Hi"})) + + assert result.content[0].text == "Hi, Alice! You are 30 years old." + + +def test_unpack_pydantic_args_nested(): + mcp = FastMCP("test") + + class Address(BaseModel): + city: str + zipcode: str + + class UserWithAddress(BaseModel): + name: str + address: Address + + # This feature currently only unpacks top-level Pydantic models. + # Nested models inside Pydantic models are kept as is (Pydantic handles them). + + @mcp.tool(unpack_pydantic_args=True) + def process_address(address: Address) -> str: + return f"{address.city} {address.zipcode}" + + tool = mcp._tool_manager._tools["process_address"] + schema = tool.parameters + + assert "city" in schema["properties"] + assert "zipcode" in schema["properties"] + + import asyncio + + result = asyncio.run(tool.run({"city": "New York", "zipcode": "10001"})) + assert result.content[0].text == "New York 10001" + + +def test_unpack_pydantic_args_collision(): + mcp = FastMCP("test") + + class User(BaseModel): + name: str + + class Admin(BaseModel): + name: str + + # Should raise ValueError due to duplicate 'name' field + with pytest.raises( + ValueError, match="Field name 'name' from Pydantic model 'Admin' conflicts" + ): + + @mcp.tool(unpack_pydantic_args=True) + def process(user: User, admin: Admin): + pass + + +def test_unpack_pydantic_args_collision_with_arg(): + mcp = FastMCP("test") + + class User(BaseModel): + name: str + + # Should raise ValueError due to duplicate 'name' field + with pytest.raises( + ValueError, match="Field name 'name' from Pydantic model 'User' conflicts" + ): + + @mcp.tool(unpack_pydantic_args=True) + def greet(name: str, user: User): + pass + + +def test_unpack_pydantic_args_default_factory(): + mcp = FastMCP("test") + + def generate_id(): + return "123" + + class Item(BaseModel): + id: str = Field(default_factory=generate_id) + name: str + + @mcp.tool(unpack_pydantic_args=True) + def create_item(item: Item) -> str: + return f"{item.id}:{item.name}" + + tool = mcp._tool_manager._tools["create_item"] + schema = tool.parameters + + # id should be required in schema because factory can't be represented statically + assert "id" in schema["required"] + assert "name" in schema["required"] + + import asyncio + + # User provides id + result = asyncio.run(tool.run({"id": "custom", "name": "test"})) + assert result.content[0].text == "custom:test" diff --git a/tests/tools/test_tool.py b/tests/tools/test_tool.py index 731bfed91..875c7a849 100644 --- a/tests/tools/test_tool.py +++ b/tests/tools/test_tool.py @@ -53,6 +53,7 @@ def add(a: int, b: int) -> int: "x-fastmcp-wrap-result": True, }, "fn": HasName("add"), + "unpack_pydantic_args": False, } ) @@ -100,6 +101,7 @@ async def fetch_data(url: str) -> str: "x-fastmcp-wrap-result": True, }, "fn": HasName("fetch_data"), + "unpack_pydantic_args": False, } ) @@ -133,6 +135,7 @@ def __call__(self, x: int, y: int) -> int: "type": "object", "x-fastmcp-wrap-result": True, }, + "unpack_pydantic_args": False, } ) @@ -166,6 +169,7 @@ async def __call__(self, x: int, y: int) -> int: "type": "object", "x-fastmcp-wrap-result": True, }, + "unpack_pydantic_args": False, } ) @@ -208,6 +212,7 @@ def create_user(user: UserInput, flag: bool) -> dict: }, "output_schema": {"additionalProperties": True, "type": "object"}, "fn": HasName("create_user"), + "unpack_pydantic_args": False, } ) @@ -269,6 +274,7 @@ def test_lambda(self): "required": ["x"], "type": "object", }, + "unpack_pydantic_args": False, } ) @@ -301,6 +307,7 @@ def add(_a: int, _b: int) -> int: "required": ["_a", "_b"], "type": "object", }, + "unpack_pydantic_args": False, } ) @@ -355,6 +362,7 @@ def add(self, x: int, y: int) -> int: "type": "object", "x-fastmcp-wrap-result": True, }, + "unpack_pydantic_args": False, } )