diff --git a/docs/servers/middleware.mdx b/docs/servers/middleware.mdx index 5ef8d8fe1..f1795e76b 100644 --- a/docs/servers/middleware.mdx +++ b/docs/servers/middleware.mdx @@ -484,6 +484,35 @@ mcp.add_middleware(StructuredLoggingMiddleware(include_payloads=True)) The built-in versions include payload logging, structured JSON output, custom logger support, payload size limits, and operation-specific hooks for granular control. +### Schema Dereference Middleware + +Some MCP clients have limited support for JSON Schema `$ref` references in tool and resource template schemas. This can lead to null parameters or validation failures when a client cannot resolve references during tool invocation. + + +The schema dereference middleware flattens internal `$ref` references (those pointing to `#/$defs/...`) in listing responses, so clients receive inline schemas that are easier to consume. + + +```python +from fastmcp import FastMCP +from fastmcp.server.middleware.schema_dereference import SchemaDereferenceMiddleware + +mcp = FastMCP("MyServer") +mcp.add_middleware(SchemaDereferenceMiddleware()) +``` + +What it does +- Inlines internal `$ref` to `#/$defs/...` for: + - on_list_tools: `Tool.parameters` + - on_list_resource_templates: `ResourceTemplate.parameters` +- Handles common schema shapes: + - Properties, arrays (`items`), maps (`additionalProperties`) + - Composition keywords (`allOf`, `oneOf`, `anyOf`) + - Transitive and reused definitions + + +Enable this middleware when serving tools or templates to clients that cannot resolve JSON Schema `$ref` internally. It is safe to use with clients that fully support `$ref` as well. + + ### Rate Limiting Middleware Rate limiting is essential for protecting your server from abuse, ensuring fair resource usage, and maintaining performance under load. FastMCP includes sophisticated rate limiting middleware at `fastmcp.server.middleware.rate_limiting`. diff --git a/src/fastmcp/server/middleware/schema_dereference.py b/src/fastmcp/server/middleware/schema_dereference.py new file mode 100644 index 000000000..3516abcb4 --- /dev/null +++ b/src/fastmcp/server/middleware/schema_dereference.py @@ -0,0 +1,201 @@ +from copy import deepcopy +from typing import Any + +import mcp.types as mt + +from fastmcp.server.middleware.middleware import CallNext, Middleware, MiddlewareContext + + +def _detect_self_reference(schema: dict) -> bool: + """ + Detect if the schema contains self-referencing definitions. + Args: + schema: The JSON schema to check + Returns: + True if self-referencing is detected + """ + defs = schema.get("$defs", {}) + + def find_refs_in_value(value: Any, parent_def: str) -> bool: + """Check if a value contains a reference to its parent definition.""" + if isinstance(value, dict): + if "$ref" in value: + ref_path = value["$ref"] + # Check if this references the parent definition + if ref_path == f"#/$defs/{parent_def}": + return True + # Check all values in the dict + for v in value.values(): + if find_refs_in_value(v, parent_def): + return True + elif isinstance(value, list): + # Check all items in the list + for item in value: + if find_refs_in_value(item, parent_def): + return True + return False + + # Check each definition for self-reference + for def_name, def_content in defs.items(): + if find_refs_in_value(def_content, def_name): + # Self-reference detected, return original schema + return True + + return False + + +def dereference_json_schema(schema: dict) -> dict: + """ + Dereference a JSON schema by resolving $ref references while preserving $defs only when corner cases occur. + This function flattens schema properties by: + 1. Check for self-reference - if found, return original schema with $defs + 2. When encountering $refs in properties, resolve them on-demand + 3. Track visited definitions globally to prevent circular expansion + 4. Only preserve original $defs if corner cases are encountered: + - Self-reference detected + - Circular references between definitions + - Reference not found in $defs + Args: + schema: The JSON schema to flatten + Returns: + Schema with references resolved in properties, keeping $defs only when corner cases occur + """ + # Step 1: Check for self-reference + if _detect_self_reference(schema): + # Self-referencing detected, return original schema with $defs + return schema + + # Make a deep copy to work with + result = deepcopy(schema) + + # Keep original $defs for potential corner cases + defs = deepcopy(schema.get("$defs", {})) + + # Track corner cases that require preserving $defs + corner_cases_detected = { + "circular_ref": False, + "ref_not_found": False, + } + + # Step 2: Define resolution function that tracks visits globally and corner cases + def resolve_refs_in_value(value: Any, depth: int, visiting: set[str]) -> Any: + """ + Recursively resolve $refs in a value. + Args: + value: The value to process + depth: Current depth in resolution + visiting: Set of definitions currently being resolved (for cycle detection) + Returns: + Value with $refs resolved (or kept if corner cases occur) + """ + if isinstance(value, dict): + if "$ref" in value: + ref_path = value["$ref"] + + # Only handle internal references to $defs + if isinstance(ref_path, str) and ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + + # Check for circular reference + if def_name in visiting: + # Circular reference detected, keep the $ref + corner_cases_detected["circular_ref"] = True + return value + + if def_name in defs: + # Add to visiting set + visiting.add(def_name) + + # Get the definition and resolve any refs within it + resolved = resolve_refs_in_value( + deepcopy(defs[def_name]), depth + 1, visiting + ) + + # Remove from visiting set + visiting.remove(def_name) + + # Merge resolved definition with additional properties + # Additional properties from the original object take precedence + for key, val in value.items(): + if key != "$ref": + resolved[key] = val + + return resolved + else: + # Definition not found, keep the $ref + corner_cases_detected["ref_not_found"] = True + return value + else: + # External ref or other type - keep as is + return value + else: + # Regular dict - process all values + return { + key: resolve_refs_in_value(val, depth, visiting) + for key, val in value.items() + } + elif isinstance(value, list): + # Process each item in the list + return [resolve_refs_in_value(item, depth, visiting) for item in value] + else: + # Primitive value - return as is + return value + + # Step 3: Process main schema properties with shared visiting set + for key, value in result.items(): + if key != "$defs": + # Each top-level property gets its own visiting set + # This allows the same definition to be used in different contexts + result[key] = resolve_refs_in_value(value, 0, set()) + + # Step 4: Conditionally preserve $defs based on corner cases + if any(corner_cases_detected.values()): + # Corner case detected, preserve original $defs + if "$defs" in schema: # Only add if original schema had $defs + result["$defs"] = defs + else: + # No corner cases, remove $defs if it exists + result.pop("$defs", None) + + return result + + +class SchemaDereferenceMiddleware(Middleware): + """Middleware that dereferences $ref in schemas for tools, resource templates. + + Applies to list handlers so that clients like Claude Desktop receive flattened schemas + without $ref in properties, preventing null parameter values. + """ + + async def on_list_tools( + self, + context: MiddlewareContext[mt.ListToolsRequest], + call_next: CallNext[mt.ListToolsRequest, list], + ) -> list: + tools = await call_next(context) + flattened = [] + for tool in tools: + params = getattr(tool, "parameters", None) + update: dict[str, Any] = {} + if isinstance(params, dict): + update["parameters"] = dereference_json_schema(params) + if update: + tool = tool.model_copy(update=update) + flattened.append(tool) + return flattened + + async def on_list_resource_templates( + self, + context: MiddlewareContext[mt.ListResourceTemplatesRequest], + call_next: CallNext[mt.ListResourceTemplatesRequest, list], + ) -> list: + templates = await call_next(context) + flattened = [] + for template in templates: + params = getattr(template, "parameters", None) + if isinstance(params, dict): + template = template.model_copy( + update={"parameters": dereference_json_schema(params)} + ) + flattened.append(template) + return flattened diff --git a/tests/server/middleware/test_schema_dereference.py b/tests/server/middleware/test_schema_dereference.py new file mode 100644 index 000000000..43ac38733 --- /dev/null +++ b/tests/server/middleware/test_schema_dereference.py @@ -0,0 +1,60 @@ +import json +from enum import Enum + +from fastmcp import Client, FastMCP +from fastmcp.server.middleware.schema_dereference import ( + SchemaDereferenceMiddleware, +) + + +class ColorEnum(str, Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +class TestSchemaDereferenceMiddleware: + async def test_dereference_enum_in_tool_parameters(self): + mcp = FastMCP("SchemaDereferenceTest") + mcp.add_middleware(SchemaDereferenceMiddleware()) + + @mcp.tool + def choose_color(color: ColorEnum) -> str: + return color.value + + async with Client(mcp) as client: + tools = await client.list_tools() + + tool = next(t for t in tools if t.name == "choose_color") + schema = tool.inputSchema + + # Ensure $ref was inlined and $defs removed for simple enum case + assert "$ref" not in json.dumps(schema) + assert "$defs" not in schema + + assert "properties" in schema and "color" in schema["properties"] + color_schema = schema["properties"]["color"] + assert color_schema.get("enum") == ["red", "green", "blue"] + assert color_schema.get("type") == "string" + + async def test_dereference_enum_in_resource_template_parameters(self): + mcp = FastMCP("SchemaDereferenceTemplateTest") + mcp.add_middleware(SchemaDereferenceMiddleware()) + + @mcp.resource("color://{color}") + def color_resource(color: ColorEnum) -> str: + return color.value + + # Use internal list to inspect template parameters after middleware + templates = await mcp._list_resource_templates() + assert len(templates) == 1 + params = templates[0].parameters + + # Ensure $ref was inlined and $defs removed for simple enum case + assert "$ref" not in json.dumps(params) + assert "$defs" not in params + + assert "properties" in params and "color" in params["properties"] + color_schema = params["properties"]["color"] + assert color_schema.get("enum") == ["red", "green", "blue"] + assert color_schema.get("type") == "string" diff --git a/tests/server/middleware/test_schema_dereference_function.py b/tests/server/middleware/test_schema_dereference_function.py new file mode 100644 index 000000000..26363661f --- /dev/null +++ b/tests/server/middleware/test_schema_dereference_function.py @@ -0,0 +1,272 @@ +import copy + +from fastmcp.server.middleware.schema_dereference import ( + dereference_json_schema, +) + + +def test_dereference_simple_enum_property_inlines_and_removes_defs(): + schema = { + "$defs": {"Color": {"type": "string", "enum": ["red", "green", "blue"]}}, + "type": "object", + "properties": {"color": {"$ref": "#/$defs/Color"}}, + "required": ["color"], + } + + result = dereference_json_schema(copy.deepcopy(schema)) + + assert "$defs" not in result + assert "$ref" not in str(result) + assert result["properties"]["color"]["enum"] == ["red", "green", "blue"] + assert result["properties"]["color"]["type"] == "string" + assert result["required"] == ["color"] + + +def test_dereference_merges_additional_properties(): + schema = { + "$defs": { + "Num": {"type": "number"}, + }, + "type": "object", + "properties": { + "n": { + "$ref": "#/$defs/Num", + "description": "a number", + "minimum": 0, + } + }, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + + n_schema = result["properties"]["n"] + assert n_schema["type"] == "number" + assert n_schema["description"] == "a number" + assert n_schema["minimum"] == 0 + assert "$defs" not in result + + +def test_dereference_allof_inlines_refs(): + schema = { + "$defs": { + "S": {"type": "string"}, + }, + "type": "object", + "properties": { + "val": { + "allOf": [ + {"$ref": "#/$defs/S"}, + {"maxLength": 10}, + ] + } + }, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + allof = result["properties"]["val"]["allOf"] + assert any(item.get("type") == "string" for item in allof) + assert any(item.get("maxLength") == 10 for item in allof) + assert "$defs" not in result + + +def test_dereference_keeps_schema_on_self_reference(): + schema = { + "$defs": { + "Node": { + "type": "object", + "properties": {"next": {"$ref": "#/$defs/Node"}}, + } + }, + "type": "object", + "properties": {"head": {"$ref": "#/$defs/Node"}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + # Self-reference: schema should be returned as-is to avoid expansion + assert result == schema + + +def test_dereference_keeps_ref_and_preserves_defs_on_circular_refs(): + schema = { + "$defs": { + "A": {"$ref": "#/$defs/B"}, + "B": {"$ref": "#/$defs/A"}, + }, + "type": "object", + "properties": {"x": {"$ref": "#/$defs/A"}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + # Circular: keep $ref and keep $defs + assert result.get("$defs") == schema["$defs"] + assert result["properties"]["x"]["$ref"] == "#/$defs/A" + + +def test_dereference_missing_ref_keeps_ref(): + schema = { + "$defs": {"Something": {"type": "integer"}}, + "type": "object", + "properties": {"x": {"$ref": "#/$defs/Missing"}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + # Missing ref: keep $ref, do not drop existing $defs + assert result.get("$defs") == schema["$defs"] + assert result["properties"]["x"]["$ref"] == "#/$defs/Missing" + + +def test_dereference_transitive_refs_are_inlined(): + schema = { + "$defs": { + "A": {"type": "object", "properties": {"b": {"$ref": "#/$defs/B"}}}, + "B": {"type": "string"}, + }, + "type": "object", + "properties": {"x": {"$ref": "#/$defs/A"}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result + x_schema = result["properties"]["x"] + assert x_schema["type"] == "object" + assert x_schema["properties"]["b"]["type"] == "string" + + +def test_dereference_array_items_ref_inlined(): + schema = { + "$defs": {"S": {"type": "string"}}, + "type": "object", + "properties": {"arr": {"type": "array", "items": {"$ref": "#/$defs/S"}}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result + assert result["properties"]["arr"]["items"]["type"] == "string" + + +def test_dereference_additional_properties_ref_inlined(): + schema = { + "$defs": {"S": {"type": "string"}}, + "type": "object", + "properties": { + "map": { + "type": "object", + "additionalProperties": {"$ref": "#/$defs/S"}, + } + }, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result + assert ( + result["properties"]["map"]["additionalProperties"]["type"] == "string" + and result["properties"]["map"]["type"] == "object" + ) + + +def test_dereference_oneof_and_anyof_refs_inlined(): + schema = { + "$defs": { + "S": {"type": "string"}, + "N": {"type": "number"}, + }, + "type": "object", + "properties": { + "v1": {"oneOf": [{"$ref": "#/$defs/S"}, {"$ref": "#/$defs/N"}]}, + "v2": {"anyOf": [{"$ref": "#/$defs/S"}, {"type": "boolean"}]}, + }, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result + oneof = result["properties"]["v1"]["oneOf"] + anyof = result["properties"]["v2"]["anyOf"] + assert any(item.get("type") == "string" for item in oneof) + assert any(item.get("type") == "number" for item in oneof) + assert any(item.get("type") == "string" for item in anyof) + assert any(item.get("type") == "boolean" for item in anyof) + + +def test_dereference_reused_def_across_properties_inlined_independently(): + schema = { + "$defs": {"S": {"type": "string"}}, + "type": "object", + "properties": { + "p1": {"$ref": "#/$defs/S"}, + "p2": {"$ref": "#/$defs/S"}, + }, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result + assert result["properties"]["p1"]["type"] == "string" + assert result["properties"]["p2"]["type"] == "string" + + +def test_dereference_mixed_present_and_missing_refs_inlines_present_and_preserves_defs(): + schema = { + "$defs": {"S": {"type": "string"}}, + "type": "object", + "properties": { + "ok": {"$ref": "#/$defs/S"}, + "bad": {"$ref": "#/$defs/Missing"}, + }, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + # Because of missing ref, defs are preserved + assert "$defs" in result + assert result["properties"]["ok"]["type"] == "string" + assert result["properties"]["bad"]["$ref"] == "#/$defs/Missing" + + +def test_dereference_self_reference_in_def_array_items_returns_original(): + schema = { + "$defs": {"Node": {"type": "array", "items": {"$ref": "#/$defs/Node"}}}, + "type": "object", + "properties": {"head": {"$ref": "#/$defs/Node"}}, + } + + original = copy.deepcopy(schema) + result = dereference_json_schema(schema) + assert result == original + + +def test_dereference_nested_defs_with_allof_are_inlined(): + schema = { + "$defs": { + "Base": {"type": "object", "properties": {"a": {"type": "integer"}}}, + "Ext": { + "allOf": [ + {"$ref": "#/$defs/Base"}, + {"type": "object", "properties": {"b": {"type": "string"}}}, + ] + }, + }, + "type": "object", + "properties": {"x": {"$ref": "#/$defs/Ext"}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result + x_schema = result["properties"]["x"] + assert "allOf" in x_schema + assert any( + item.get("properties", {}).get("a", {}).get("type") == "integer" + for item in x_schema["allOf"] + ) + assert any( + item.get("properties", {}).get("b", {}).get("type") == "string" + for item in x_schema["allOf"] + ) + + +def test_dereference_removes_empty_defs_section(): + schema = { + "$defs": {}, + "type": "object", + "properties": {"n": {"type": "number"}}, + } + + result = dereference_json_schema(copy.deepcopy(schema)) + assert "$defs" not in result