Skip to content

Commit e6b226a

Browse files
committed
fix(vLLM): Add tool calling support to VLLMClient.chat()
Fixes #4871 Previously, using GRPOTrainer with `vllm_mode="server"` raised a `NotImplementedError` when tools were passed to `VLLMClient.chat()`. This prevented users from using tool calling features with the vLLM server mode. Changes: - Remove the NotImplementedError check in VLLMClient.chat() - Add `tools` parameter to the HTTP request payload - Add `tools` field to ChatRequest model in vllm_serve.py - Pass tools to vLLM's chat() method on the server side
1 parent 3066891 commit e6b226a

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

trl/extras/vllm_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,6 @@ def chat(
352352
- `logprobs` (`list[list[float]]`):
353353
List of lists of log probabilities for each generated token.
354354
"""
355-
if tools is not None:
356-
raise NotImplementedError("Tool calling is not yet implemented in VLLMClient.chat().")
357355
if chat_template is not None:
358356
raise NotImplementedError("Custom chat templates are not yet implemented in VLLMClient.chat().")
359357

@@ -383,6 +381,7 @@ def chat(
383381
"structured_outputs_regex": structured_outputs_regex,
384382
"generation_kwargs": generation_kwargs or {},
385383
"chat_template_kwargs": chat_template_kwargs or {},
384+
"tools": tools,
386385
},
387386
)
388387
if response.status_code == 200:

trl/scripts/vllm_serve.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,7 @@ class ChatRequest(BaseModel):
644644
structured_outputs_regex: str | None = None
645645
generation_kwargs: dict = field(default_factory=dict)
646646
chat_template_kwargs: dict = field(default_factory=dict)
647+
tools: list | None = None
647648

648649
class ChatResponse(BaseModel):
649650
prompt_ids: list[list[int]]
@@ -756,6 +757,7 @@ async def chat(request: ChatRequest):
756757
"messages": messages,
757758
"sampling_params": sampling_params,
758759
"chat_template_kwargs": request.chat_template_kwargs,
760+
"tools": request.tools,
759761
}
760762
connection.send({"type": "call", "method": "chat", "kwargs": kwargs})
761763

0 commit comments

Comments
 (0)