Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,53 @@ def get_engine() -> BaseEngine:
return _engine


def _coerce_tool_arguments(
arguments_json: str, tool_name: str, tools: list[dict] | None
) -> str:
"""
Coerce tool call arguments to match the tool schema.

If a schema field expects "string" but the model produced an object/array,
JSON-stringify the value. This fixes a common LLM failure mode where models
output raw JSON objects instead of JSON strings for file content, etc.
"""
if not tools:
return arguments_json

# Find the schema for this tool
schema = None
for tool in tools:
if isinstance(tool, dict) and tool.get("function", {}).get("name") == tool_name:
schema = tool["function"].get("parameters", {})
break

if not schema or "properties" not in schema:
return arguments_json

try:
arguments = json.loads(arguments_json)
except (json.JSONDecodeError, TypeError):
return arguments_json

if not isinstance(arguments, dict):
return arguments_json

properties = schema.get("properties", {})
changed = False

for key, value in arguments.items():
if key in properties:
expected_type = properties[key].get("type")
if expected_type == "string" and isinstance(value, (dict, list)):
arguments[key] = json.dumps(value, ensure_ascii=False, indent=2)
changed = True

if changed:
return json.dumps(arguments, ensure_ascii=False)

return arguments_json


def _parse_tool_calls_with_parser(
output_text: str, request: ChatCompletionRequest | None = None
) -> tuple[str, list | None]:
Expand Down Expand Up @@ -382,13 +429,16 @@ def _parse_tool_calls_with_parser(
_tool_parser_instance.reset()
result = _tool_parser_instance.extract_tool_calls(output_text, request_dict)
if result.tools_called:
tools = request_dict.get("tools") if request_dict else None
tool_calls = [
ToolCall(
id=tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
type="function",
function=FunctionCall(
name=tc["name"],
arguments=tc["arguments"],
arguments=_coerce_tool_arguments(
tc["arguments"], tc["name"], tools
),
),
)
for tc in result.tool_calls
Expand Down Expand Up @@ -1980,6 +2030,19 @@ async def stream_chat_completion(
if "tool_calls" in tool_result:
# Emit structured tool calls
tool_calls_detected = True
# Coerce arguments against tool schemas
tools = (
request.model_dump().get("tools")
if request and request.tools
else None
)
if tools:
for tc in tool_result["tool_calls"]:
fn = tc.get("function", {})
if "arguments" in fn and "name" in fn:
fn["arguments"] = _coerce_tool_arguments(
fn["arguments"], fn["name"], tools
)
chunk = ChatCompletionChunk(
id=response_id,
model=request.model,
Expand Down Expand Up @@ -2030,6 +2093,9 @@ async def stream_chat_completion(
):
result = tool_parser.extract_tool_calls(tool_accumulated_text)
if result.tools_called:
tools = (
request.model_dump().get("tools") if request and request.tools else None
)
tool_chunk = ChatCompletionChunk(
id=response_id,
model=request.model,
Expand All @@ -2043,7 +2109,9 @@ async def stream_chat_completion(
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"],
"arguments": _coerce_tool_arguments(
tc["arguments"], tc["name"], tools
),
},
}
for i, tc in enumerate(result.tool_calls)
Expand Down
Loading