Skip to content

Commit cf2a504

Browse files
committed
Emit strict object-shaped input schemas for MCP tools; remove integer types; add schema test
1 parent 7f0fdd3 commit cf2a504

19 files changed

+422
-69
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ dependencies = [
88
"python-dotenv>=0.23.0",
99
"httpx>=0.28.1",
1010
"mcp[cli]>=1.9.3",
11-
"requests"
11+
"pydantic>=2",
12+
"requests>=2"
1213
]
1314

1415
[tool.pytest.ini_options]
15-
pythonpath = "src"
16+
pythonpath = ["src"]

src/main.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,33 @@
4343
for f in registration_functions:
4444
f(mcp)
4545

46+
# Temporary shim to normalize tool input schemas for Codex/Anthropic
47+
async def _list_tools_with_shim():
48+
tools = await original_list_tools()
49+
for tool in tools:
50+
schema = tool.inputSchema or {"type": "object", "properties": {}}
51+
if schema.get("properties") and list(schema["properties"].keys()) == ["args"]:
52+
schema = schema["properties"]["args"]
53+
schema.setdefault("type", "object")
54+
schema.setdefault("additionalProperties", False)
55+
56+
def _fix(node):
57+
if isinstance(node, dict):
58+
if node.get("type") == "integer":
59+
node["type"] = "number"
60+
for v in node.values():
61+
_fix(v)
62+
elif isinstance(node, list):
63+
for v in node:
64+
_fix(v)
65+
66+
_fix(schema)
67+
tool.inputSchema = schema
68+
return tools
69+
70+
original_list_tools = mcp.list_tools
71+
mcp.list_tools = _list_tools_with_shim
72+
4673
if __name__ == "__main__":
4774
# Run the server.
4875
mcp.run(transport=transport)

src/tool_args.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict, List, Type, Union, get_args, get_origin
4+
5+
from pydantic import BaseModel, Field, create_model, field_validator
6+
7+
8+
class ArgsBaseModel(BaseModel):
9+
"""Base class for tool argument models with strict object schema."""
10+
11+
model_config = {
12+
"extra": "forbid",
13+
"json_schema_extra": {"additionalProperties": False},
14+
}
15+
16+
17+
_MODEL_CACHE: Dict[type[BaseModel], type[BaseModel]] = {}
18+
19+
20+
def _convert_annotation(
21+
annotation: Any,
22+
field_name: str,
23+
validators: Dict[str, classmethod],
24+
) -> Any:
25+
origin = get_origin(annotation)
26+
if origin is None:
27+
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
28+
return get_args_model(annotation)
29+
if annotation is int:
30+
# Replace ints with floats and create validator to coerce to int
31+
@field_validator(field_name, mode="before")
32+
def _coerce(cls, v):
33+
return None if v is None else int(v)
34+
35+
validators[f"coerce_{field_name}"] = _coerce
36+
return float
37+
return annotation
38+
if origin in (list, List):
39+
inner = _convert_annotation(get_args(annotation)[0], field_name, validators)
40+
return List[inner]
41+
if origin is Union:
42+
converted = [
43+
_convert_annotation(arg, field_name, validators) for arg in get_args(annotation)
44+
]
45+
return Union[tuple(converted)]
46+
return annotation
47+
48+
49+
def get_args_model(model_cls: Type[BaseModel]) -> Type[BaseModel]:
50+
"""Create an Args model for the given request model."""
51+
if model_cls in _MODEL_CACHE:
52+
return _MODEL_CACHE[model_cls]
53+
54+
fields: Dict[str, tuple[Any, Any]] = {}
55+
validators: Dict[str, classmethod] = {}
56+
57+
for name, field in model_cls.model_fields.items():
58+
ann = _convert_annotation(field.annotation, name, validators)
59+
default = field.default if not field.is_required() else ...
60+
ge = le = None
61+
for meta in getattr(field, "metadata", []):
62+
if hasattr(meta, "ge"):
63+
ge = meta.ge
64+
if hasattr(meta, "le"):
65+
le = meta.le
66+
fields[name] = (
67+
ann,
68+
Field(
69+
default,
70+
description=getattr(field, "description", None),
71+
ge=ge,
72+
le=le,
73+
),
74+
)
75+
76+
ArgsModel = create_model(
77+
f"{model_cls.__name__}Args",
78+
__base__=ArgsBaseModel,
79+
__validators__=validators,
80+
**fields,
81+
)
82+
83+
_MODEL_CACHE[model_cls] = ArgsModel
84+
return ArgsModel
85+
86+
87+
from typing import Callable
88+
89+
90+
def tool_with_args(
91+
mcp,
92+
request_model: Type[BaseModel] | None = None,
93+
**decorator_kwargs: Any,
94+
) -> Callable:
95+
"""Decorator to wrap MCP tools with generated Args models."""
96+
97+
def decorator(func: Callable) -> Callable:
98+
if request_model is None:
99+
ArgsModel = ArgsBaseModel
100+
101+
async def inner(args: ArgsModel):
102+
return await func()
103+
else:
104+
ArgsModel = get_args_model(request_model)
105+
106+
async def inner(args: ArgsModel):
107+
model = request_model(**args.model_dump())
108+
return await func(model)
109+
110+
inner.__name__ = func.__name__
111+
inner.__doc__ = func.__doc__
112+
inner.__annotations__ = {
113+
"args": ArgsModel,
114+
"return": func.__annotations__.get("return", Any),
115+
}
116+
return mcp.tool(**decorator_kwargs)(inner)
117+
118+
return decorator

src/tools/account.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from api_connection import post
22
from models import AccountResponse
3+
from tool_args import tool_with_args
34

45
def register_account_tools(mcp):
56
# Read
6-
@mcp.tool(
7+
@tool_with_args(
8+
mcp,
79
annotations={
810
'title': 'Read account',
911
'readOnlyHint': True,

src/tools/ai.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
SyntaxCheckResponse,
1515
SearchResponse
1616
)
17+
from tool_args import tool_with_args
1718

1819
def register_ai_tools(mcp):
1920
# Get backtest initialization errors
20-
@mcp.tool(
21+
@tool_with_args(
22+
mcp,
23+
BasicFilesRequest,
2124
annotations={
2225
'title': 'Check initialization errors', 'readOnlyHint': True
2326
}
@@ -29,14 +32,20 @@ async def check_initialization_errors(
2932
return await post('/ai/tools/backtest-init', model)
3033

3134
# Complete code
32-
@mcp.tool(annotations={'title': 'Complete code', 'readOnlyHint': True})
35+
@tool_with_args(
36+
mcp,
37+
CodeCompletionRequest,
38+
annotations={'title': 'Complete code', 'readOnlyHint': True}
39+
)
3340
async def complete_code(
3441
model: CodeCompletionRequest) -> CodeCompletionResponse:
3542
"""Show the code completion for a specific text input."""
3643
return await post('/ai/tools/complete', model)
3744

3845
# Enchance error message
39-
@mcp.tool(
46+
@tool_with_args(
47+
mcp,
48+
ErrorEnhanceRequest,
4049
annotations={'title': 'Enhance error message', 'readOnlyHint': True}
4150
)
4251
async def enhance_error_message(
@@ -45,7 +54,9 @@ async def enhance_error_message(
4554
return await post('/ai/tools/error-enhance', model)
4655

4756
# Update code to PEP8
48-
@mcp.tool(
57+
@tool_with_args(
58+
mcp,
59+
PEP8ConvertRequest,
4960
annotations={'title': 'Update code to PEP8', 'readOnlyHint': True}
5061
)
5162
async def update_code_to_pep8(
@@ -54,13 +65,21 @@ async def update_code_to_pep8(
5465
return await post('/ai/tools/pep8-convert', model)
5566

5667
# Check syntax
57-
@mcp.tool(annotations={'title': 'Check syntax', 'readOnlyHint': True})
68+
@tool_with_args(
69+
mcp,
70+
BasicFilesRequest,
71+
annotations={'title': 'Check syntax', 'readOnlyHint': True}
72+
)
5873
async def check_syntax(model: BasicFilesRequest) -> SyntaxCheckResponse:
5974
"""Check the syntax of a code."""
6075
return await post('/ai/tools/syntax-check', model)
6176

6277
# Search
63-
@mcp.tool(annotations={'title': 'Search QuantConnect', 'readOnlyHint': True})
78+
@tool_with_args(
79+
mcp,
80+
SearchRequest,
81+
annotations={'title': 'Search QuantConnect', 'readOnlyHint': True}
82+
)
6483
async def search_quantconnect(model: SearchRequest) -> SearchResponse:
6584
"""Search for content in QuantConnect."""
6685
return await post('/ai/tools/search', model)

src/tools/backtests.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
BacktestReportGeneratingResponse,
2020
RestResponse
2121
)
22+
from tool_args import tool_with_args
2223

2324
def register_backtest_tools(mcp):
2425
# Create
25-
@mcp.tool(
26+
@tool_with_args(
27+
mcp,
28+
CreateBacktestRequest,
2629
annotations={
2730
'title': 'Create backtest',
2831
'destructiveHint': False
@@ -34,20 +37,30 @@ async def create_backtest(
3437
return await post('/backtests/create', model)
3538

3639
# Read statistics for a single backtest.
37-
@mcp.tool(annotations={'title': 'Read backtest', 'readOnlyHint': True})
40+
@tool_with_args(
41+
mcp,
42+
ReadBacktestRequest,
43+
annotations={'title': 'Read backtest', 'readOnlyHint': True}
44+
)
3845
async def read_backtest(model: ReadBacktestRequest) -> BacktestResponse:
3946
"""Read the results of a backtest."""
4047
return await post('/backtests/read', model)
4148

4249
# Read a summary of all the backtests.
43-
@mcp.tool(annotations={'title': 'List backtests', 'readOnlyHint': True})
50+
@tool_with_args(
51+
mcp,
52+
ListBacktestRequest,
53+
annotations={'title': 'List backtests', 'readOnlyHint': True}
54+
)
4455
async def list_backtests(
4556
model: ListBacktestRequest) -> BacktestSummaryResponse:
4657
"""List all the backtests for the project."""
4758
return await post('/backtests/list', model)
4859

4960
# Read the chart of a single backtest.
50-
@mcp.tool(
61+
@tool_with_args(
62+
mcp,
63+
ReadBacktestChartRequest,
5164
annotations={'title': 'Read backtest chart', 'readOnlyHint': True}
5265
)
5366
async def read_backtest_chart(
@@ -56,7 +69,9 @@ async def read_backtest_chart(
5669
return await post('/backtests/chart/read', model)
5770

5871
# Read the orders of a single backtest.
59-
@mcp.tool(
72+
@tool_with_args(
73+
mcp,
74+
ReadBacktestOrdersRequest,
6075
annotations={'title': 'Read backtest orders', 'readOnlyHint': True}
6176
)
6277
async def read_backtest_orders(
@@ -65,7 +80,9 @@ async def read_backtest_orders(
6580
return await post('/backtests/orders/read', model)
6681

6782
# Read the insights of a single backtest.
68-
@mcp.tool(
83+
@tool_with_args(
84+
mcp,
85+
ReadBacktestInsightsRequest,
6986
annotations={'title': 'Read backtest insights', 'readOnlyHint': True}
7087
)
7188
async def read_backtest_insights(
@@ -84,15 +101,19 @@ async def read_backtest_insights(
84101
# return await post('/backtests/read/report', model)
85102

86103
# Update
87-
@mcp.tool(
104+
@tool_with_args(
105+
mcp,
106+
UpdateBacktestRequest,
88107
annotations={'title': 'Update backtest', 'idempotentHint': True}
89108
)
90109
async def update_backtest(model: UpdateBacktestRequest) -> RestResponse:
91110
"""Update the name or note of a backtest."""
92111
return await post('/backtests/update', model)
93112

94113
# Delete
95-
@mcp.tool(
114+
@tool_with_args(
115+
mcp,
116+
DeleteBacktestRequest,
96117
annotations={'title': 'Delete backtest', 'idempotentHint': True}
97118
)
98119
async def delete_backtest(model: DeleteBacktestRequest) -> RestResponse:

src/tools/compile.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
CreateCompileResponse,
66
ReadCompileResponse
77
)
8+
from tool_args import tool_with_args
89

910
def register_compile_tools(mcp):
1011
# Create
11-
@mcp.tool(
12+
@tool_with_args(
13+
mcp,
14+
CreateCompileRequest,
1215
annotations={'title': 'Create compile', 'destructiveHint': False}
1316
)
1417
async def create_compile(
@@ -17,7 +20,11 @@ async def create_compile(
1720
return await post('/compile/create', model)
1821

1922
# Read
20-
@mcp.tool(annotations={'title': 'Read compile', 'readOnlyHint': True})
23+
@tool_with_args(
24+
mcp,
25+
ReadCompileRequest,
26+
annotations={'title': 'Read compile', 'readOnlyHint': True}
27+
)
2128
async def read_compile(model: ReadCompileRequest) -> ReadCompileResponse:
2229
"""Read a compile packet job result."""
2330
return await post('/compile/read', model)

0 commit comments

Comments
 (0)