Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/fastmcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,7 @@ def tool(
exclude_args: list[str] | None = None,
meta: dict[str, Any] | None = None,
enabled: bool | None = None,
unpack_pydantic_args: bool = False,
) -> FunctionTool: ...

@overload
Expand All @@ -1397,6 +1398,7 @@ def tool(
exclude_args: list[str] | None = None,
meta: dict[str, Any] | None = None,
enabled: bool | None = None,
unpack_pydantic_args: bool = False,
) -> Callable[[AnyFunction], FunctionTool]: ...

def tool(
Expand All @@ -1413,6 +1415,7 @@ def tool(
exclude_args: list[str] | None = None,
meta: dict[str, Any] | None = None,
enabled: bool | None = None,
unpack_pydantic_args: bool = False,
) -> Callable[[AnyFunction], FunctionTool] | FunctionTool:
"""Decorator to register a tool.

Expand Down Expand Up @@ -1501,6 +1504,7 @@ def my_tool(x: int) -> str:
meta=meta,
serializer=self._tool_serializer,
enabled=enabled,
unpack_pydantic_args=unpack_pydantic_args,
)
self.add_tool(tool)
return tool
Expand Down Expand Up @@ -1534,6 +1538,7 @@ def my_tool(x: int) -> str:
exclude_args=exclude_args,
meta=meta,
enabled=enabled,
unpack_pydantic_args=unpack_pydantic_args,
)

def add_resource(self, resource: Resource) -> Resource:
Expand Down
86 changes: 83 additions & 3 deletions src/fastmcp/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def from_function(
serializer: ToolResultSerializerType | None = None,
meta: dict[str, Any] | None = None,
enabled: bool | None = None,
unpack_pydantic_args: bool = False,
) -> FunctionTool:
"""Create a Tool from a function."""
return FunctionTool.from_function(
Expand All @@ -203,6 +204,7 @@ def from_function(
serializer=serializer,
meta=meta,
enabled=enabled,
unpack_pydantic_args=unpack_pydantic_args,
)

async def run(self, arguments: dict[str, Any]) -> ToolResult:
Expand Down Expand Up @@ -254,6 +256,8 @@ def from_tool(

class FunctionTool(Tool):
fn: Callable[..., Any]
unpack_pydantic_args: bool = False
unpacked_models_map: dict | None = None

@classmethod
def from_function(
Expand All @@ -270,6 +274,7 @@ def from_function(
serializer: ToolResultSerializerType | None = None,
meta: dict[str, Any] | None = None,
enabled: bool | None = None,
unpack_pydantic_args: bool = False,
) -> FunctionTool:
"""Create a Tool from a function."""
if exclude_args and fastmcp.settings.deprecation_warnings:
Expand All @@ -283,7 +288,11 @@ def from_function(
stacklevel=2,
)

parsed_fn = ParsedFunction.from_function(fn, exclude_args=exclude_args)
parsed_fn = ParsedFunction.from_function(
fn,
exclude_args=exclude_args,
unpack_pydantic_args=unpack_pydantic_args,
)

if name is None and parsed_fn.name == "<lambda>":
raise ValueError("You must provide a name for lambda functions")
Expand Down Expand Up @@ -325,6 +334,8 @@ def from_function(
serializer=serializer,
meta=meta,
enabled=enabled if enabled is not None else True,
unpack_pydantic_args=unpack_pydantic_args,
unpacked_models_map=parsed_fn.unpacked_models_map,
)

async def run(self, arguments: dict[str, Any]) -> ToolResult:
Expand All @@ -333,6 +344,28 @@ async def run(self, arguments: dict[str, Any]) -> ToolResult:

arguments = arguments.copy()

# If unpacking is enabled, re-assemble the Pydantic models from the arguments
if self.unpack_pydantic_args and self.unpacked_models_map:
assembled_args = {}
consumed_keys = set()

for arg_name, model_cls in self.unpacked_models_map.items():
model_fields = model_cls.model_fields.keys()
model_kwargs = {}
for field_name in model_fields:
if field_name in arguments:
model_kwargs[field_name] = arguments[field_name]
consumed_keys.add(field_name)

assembled_args[arg_name] = model_cls(**model_kwargs)

# Add the remaining non-pydantic arguments
for key, value in arguments.items():
if key not in consumed_keys:
assembled_args[key] = value

arguments = assembled_args

context_kwarg = find_kwarg_by_type(self.fn, kwarg_type=Context)
if context_kwarg and context_kwarg not in arguments:
arguments[context_kwarg] = get_context()
Expand Down Expand Up @@ -400,6 +433,7 @@ class ParsedFunction:
description: str | None
input_schema: dict[str, Any]
output_schema: dict[str, Any] | None
unpacked_models_map: dict[str, Any] | None = None

@classmethod
def from_function(
Expand All @@ -408,6 +442,7 @@ def from_function(
exclude_args: list[str] | None = None,
validate: bool = True,
wrap_non_object_output_schema: bool = True,
unpack_pydantic_args: bool = False,
) -> ParsedFunction:
from fastmcp.server.context import Context

Expand Down Expand Up @@ -453,12 +488,56 @@ def from_function(
if exclude_args:
prune_params.extend(exclude_args)

unpacked_models_map = {}
fn_for_schema = fn

if unpack_pydantic_args:
import pydantic

original_sig = inspect.signature(fn)
new_params = []
has_pydantic_model = False

for param in original_sig.parameters.values():
# Check if the parameter type is a Pydantic model
if isinstance(param.annotation, type) and issubclass(param.annotation, pydantic.BaseModel):
has_pydantic_model = True
unpacked_models_map[param.name] = param.annotation
# Unpack the model's fields into new parameters
for field_name, field in param.annotation.model_fields.items():
# Create a new parameter for each field
new_param = inspect.Parameter(
name=field_name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=field.default if not field.is_required() else inspect.Parameter.empty,
annotation=field.annotation,
)
new_params.append(new_param)
else:
new_params.append(param)

if has_pydantic_model:
# Sort parameters to ensure non-defaults come first
required_params = [p for p in new_params if p.default == inspect.Parameter.empty]
optional_params = [p for p in new_params if p.default != inspect.Parameter.empty]
new_params = required_params + optional_params

# Create a new function with the unpacked signature for schema generation
def placeholder_fn(*args, **kwargs): ...

new_sig = original_sig.replace(parameters=new_params)
new_annotations = {p.name: p.annotation for p in new_params if p.annotation != inspect.Parameter.empty}

placeholder_fn.__signature__ = new_sig
placeholder_fn.__annotations__ = new_annotations
fn_for_schema = placeholder_fn

# Create a function without excluded parameters in annotations
# This prevents Pydantic from trying to serialize non-serializable types
# before we can exclude them in compress_schema
fn_for_typeadapter = fn
fn_for_typeadapter = fn_for_schema
if prune_params:
fn_for_typeadapter = create_function_without_params(fn, prune_params)
fn_for_typeadapter = create_function_without_params(fn_for_schema, prune_params)

input_type_adapter = get_cached_typeadapter(fn_for_typeadapter)
input_schema = input_type_adapter.json_schema()
Expand Down Expand Up @@ -535,6 +614,7 @@ def from_function(
description=fn_doc,
input_schema=input_schema,
output_schema=output_schema or None,
unpacked_models_map=unpacked_models_map or None,
)


Expand Down
66 changes: 66 additions & 0 deletions tests/test_unpack_pydantic_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from pydantic import BaseModel, Field
from fastmcp import FastMCP


class User(BaseModel):
name: str = Field(description="The user's name")
age: int = Field(description="The user's age")


def test_unpack_pydantic_args():
mcp = FastMCP("test")

@mcp.tool(unpack_pydantic_args=True)
def greet_user(user: User, greeting: str = "Hello") -> str:
return f"{greeting}, {user.name}! You are {user.age} years old."

tool = mcp._tool_manager._tools["greet_user"]

# Check schema
schema = tool.parameters
assert "name" in schema["properties"]
assert "age" in schema["properties"]
assert "greeting" in schema["properties"]
assert "user" not in schema["properties"]

# Check required fields
assert "name" in schema["required"]
assert "age" in schema["required"]
assert "greeting" not in schema.get("required", [])

# Run tool
import asyncio

result = asyncio.run(tool.run({"name": "Alice", "age": 30, "greeting": "Hi"}))

assert result.content[0].text == "Hi, Alice! You are 30 years old."


def test_unpack_pydantic_args_nested():
mcp = FastMCP("test")

class Address(BaseModel):
city: str
zipcode: str

class UserWithAddress(BaseModel):
name: str
address: Address

# This feature currently only unpacks top-level Pydantic models.
# Nested models inside Pydantic models are kept as is (Pydantic handles them).

@mcp.tool(unpack_pydantic_args=True)
def process_address(address: Address) -> str:
return f"{address.city} {address.zipcode}"

tool = mcp._tool_manager._tools["process_address"]
schema = tool.parameters

assert "city" in schema["properties"]
assert "zipcode" in schema["properties"]

import asyncio

result = asyncio.run(tool.run({"city": "New York", "zipcode": "10001"}))
assert result.content[0].text == "New York 10001"
8 changes: 8 additions & 0 deletions tests/tools/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def add(a: int, b: int) -> int:
"x-fastmcp-wrap-result": True,
},
"fn": HasName("add"),
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -100,6 +101,7 @@ async def fetch_data(url: str) -> str:
"x-fastmcp-wrap-result": True,
},
"fn": HasName("fetch_data"),
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -133,6 +135,7 @@ def __call__(self, x: int, y: int) -> int:
"type": "object",
"x-fastmcp-wrap-result": True,
},
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -166,6 +169,7 @@ async def __call__(self, x: int, y: int) -> int:
"type": "object",
"x-fastmcp-wrap-result": True,
},
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -208,6 +212,7 @@ def create_user(user: UserInput, flag: bool) -> dict:
},
"output_schema": {"additionalProperties": True, "type": "object"},
"fn": HasName("create_user"),
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -269,6 +274,7 @@ def test_lambda(self):
"required": ["x"],
"type": "object",
},
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -301,6 +307,7 @@ def add(_a: int, _b: int) -> int:
"required": ["_a", "_b"],
"type": "object",
},
"unpack_pydantic_args": False,
}
)

Expand Down Expand Up @@ -355,6 +362,7 @@ def add(self, x: int, y: int) -> int:
"type": "object",
"x-fastmcp-wrap-result": True,
},
"unpack_pydantic_args": False,
}
)

Expand Down
Loading