Skip to content

Commit ac03915

Browse files
authored
[TRTLLM-9604][feat] DS R1 & V3.1 tool parser (#10010)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent 31bc14b commit ac03915

File tree

6 files changed

+767
-34
lines changed

6 files changed

+767
-34
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
{% if not add_generation_prompt is defined %}
2+
{% set add_generation_prompt = false %}
3+
{% endif %}
4+
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}
5+
{%- for message in messages %}
6+
{%- if message['role'] == 'system' %}
7+
{%- if ns.is_first_sp %}
8+
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
9+
{% set ns.is_first_sp = false %}
10+
{%- else %}
11+
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
12+
{%- endif %}
13+
{%- endif %}
14+
{%- endfor -%}
15+
16+
{#- Adapted from https://github.com/sgl-project/sglang/blob/main/examples/chat_template/tool_chat_template_deepseekr1.jinja #}
17+
{% if tools is defined and tools is not none %}
18+
{% set tool_ns = namespace(text='You are a helpful assistant with tool calling capabilities. '
19+
'When a tool call is needed, you MUST use the following format to issue the call:\n'
20+
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>FUNCTION_NAME\n'
21+
'```json\n{"param1": "value1", "param2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n\n'
22+
'Make sure the JSON is valid.'
23+
'## Tools\n\n### Function\n\nYou have the following functions available:\n\n') %}
24+
{% for tool in tools %}
25+
{% set tool_ns.text = tool_ns.text + '\n```json\n' + (tool | tojson) + '\n```\n' %}
26+
{% endfor %}
27+
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
28+
{% endif %}
29+
30+
{{- bos_token }}
31+
{{- ns.system_prompt }}
32+
{%- for message in messages %}
33+
{% set content = message['content'] %}
34+
{%- if message['role'] == 'user' %}
35+
{%- set ns.is_tool = false -%}
36+
{%- set ns.is_first = false -%}
37+
{%- set ns.is_last_user = true -%}
38+
{{'<|User|>' + content + '<|Assistant|>'}}
39+
{%- endif %}
40+
{%- if message['role'] == 'assistant' %}
41+
{% if '</think>' in content %}
42+
{% set content = content.split('</think>')[-1] %}
43+
{% endif %}
44+
{% endif %}
45+
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
46+
{%- set ns.is_last_user = false -%}
47+
{%- if ns.is_tool %}
48+
{{- '<|tool▁outputs▁end|>'}}
49+
{%- endif %}
50+
{%- set ns.is_first = false %}
51+
{%- set ns.is_tool = false -%}
52+
{%- set ns.is_output_first = true %}
53+
{%- for tool in message['tool_calls'] %}
54+
{%- if not ns.is_first %}
55+
{%- if content is none %}
56+
{{- '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
57+
{%- else %}
58+
{{- content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
59+
{%- endif %}
60+
{%- set ns.is_first = true -%}
61+
{%- else %}
62+
{{- '\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
63+
{%- endif %}
64+
{%- endfor %}
65+
{{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}}
66+
{%- endif %}
67+
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}
68+
{%- set ns.is_last_user = false -%}
69+
{%- if ns.is_tool %}
70+
{{- '<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}
71+
{%- set ns.is_tool = false -%}
72+
{%- else %}
73+
{{- content + '<|end▁of▁sentence|>'}}
74+
{%- endif %}
75+
{%- endif %}
76+
{%- if message['role'] == 'tool' %}
77+
{%- set ns.is_last_user = false -%}
78+
{%- set ns.is_tool = true -%}
79+
{%- if ns.is_output_first %}
80+
{{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
81+
{%- set ns.is_output_first = false %}
82+
{%- else %}
83+
{{- '\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
84+
{%- endif %}
85+
{%- endif %}
86+
{%- endfor -%}
87+
{% if ns.is_tool %}
88+
{{- '<|tool▁outputs▁end|>'}}
89+
{%- endif %}
90+
{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}
91+
{{- '<|Assistant|>'}}
92+
{%- endif %}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
{% if not add_generation_prompt is defined %}
2+
{% set add_generation_prompt = false %}
3+
{% endif %}
4+
{% if not thinking is defined %}
5+
{% set thinking = false %}
6+
{% endif %}
7+
{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}
8+
{%- for message in messages %}
9+
{%- if message['role'] == 'system' %}
10+
{%- if ns.is_first_sp %}
11+
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
12+
{% set ns.is_first_sp = false %}
13+
{%- else %}
14+
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
15+
{%- endif %}
16+
{%- endif %}
17+
{%- endfor %}
18+
19+
{% if tools is defined and tools is not none %}
20+
{% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %}
21+
{% for tool in tools %}
22+
{% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %}
23+
{% endfor %}
24+
{% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %}
25+
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
26+
{% endif %}
27+
28+
{{ bos_token }}{{ ns.system_prompt }}
29+
{%- for message in messages %}
30+
{%- if message['role'] == 'user' %}
31+
{%- set ns.is_tool = false -%}
32+
{%- set ns.is_first = false -%}
33+
{%- set ns.is_last_user = true -%}
34+
{{'<|User|>' + message['content']}}
35+
{%- endif %}
36+
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
37+
{%- if ns.is_last_user %}
38+
{{'<|Assistant|></think>'}}
39+
{%- endif %}
40+
{%- set ns.is_last_user = false -%}
41+
{%- set ns.is_first = false %}
42+
{%- set ns.is_tool = false -%}
43+
{%- for tool in message['tool_calls'] %}
44+
{%- if not ns.is_first %}
45+
{%- if message['content'] is none %}
46+
{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
47+
{%- else %}
48+
{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
49+
{%- endif %}
50+
{%- set ns.is_first = true -%}
51+
{%- else %}
52+
{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
53+
{%- endif %}
54+
{%- endfor %}
55+
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
56+
{%- endif %}
57+
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
58+
{%- if ns.is_last_user %}
59+
{{'<|Assistant|>'}}
60+
{%- if message['prefix'] is defined and message['prefix'] and thinking %}
61+
{{'<think>'}}
62+
{%- else %}
63+
{{'</think>'}}
64+
{%- endif %}
65+
{%- endif %}
66+
{%- set ns.is_last_user = false -%}
67+
{%- if ns.is_tool %}
68+
{{message['content'] + '<|end▁of▁sentence|>'}}
69+
{%- set ns.is_tool = false -%}
70+
{%- else %}
71+
{%- set content = message['content'] -%}
72+
{%- if '</think>' in content %}
73+
{%- set content = content.split('</think>', 1)[1] -%}
74+
{%- endif %}
75+
{{content + '<|end▁of▁sentence|>'}}
76+
{%- endif %}
77+
{%- endif %}
78+
{%- if message['role'] == 'tool' %}
79+
{%- set ns.is_last_user = false -%}
80+
{%- set ns.is_tool = true -%}
81+
{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
82+
{%- endif %}
83+
{%- endfor -%}
84+
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}
85+
{{'<|Assistant|>'}}
86+
{%- if not thinking %}
87+
{{'</think>'}}
88+
{%- else %}
89+
{{'<think>'}}
90+
{%- endif %}
91+
{% endif %}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Adapted from https://github.com/sgl-project/sglang/blob/94e1251131ca27260cb0e8938aeb7b4a4e630b19/python/sglang/srt/function_call/deepseekv31_detector.py
2+
import json
3+
import re
4+
from typing import List
5+
6+
from tensorrt_llm.logger import logger
7+
from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam as Tool
8+
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
9+
from tensorrt_llm.serve.tool_parser.core_types import (
10+
StreamingParseResult,
11+
StructureInfo,
12+
ToolCallItem,
13+
_GetInfoFunc,
14+
)
15+
16+
from .utils import is_complete_json
17+
18+
19+
class DeepSeekV31Parser(BaseToolParser):
20+
(
21+
"""Tool parser for DeepSeek V3 model function call format.
22+
23+
The DeepSeek V3 format uses special Unicode tokens to delimit function calls
24+
with JSON code blocks for arguments.
25+
26+
Format Structure:
27+
```
28+
<|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁calls▁end|><|end▁of▁sentence|>
29+
```
30+
Examples:
31+
```
32+
"""
33+
"""<|tool▁calls▁begin|>"""
34+
"""<|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|>"""
35+
"""<|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|>"""
36+
"""<|tool▁calls▁end|><|end▁of▁sentence|>
37+
```
38+
39+
Key Components:
40+
- Tool Calls Section: Wrapped between `<|tool▁calls▁begin|>` and `<|tool▁calls▁end|>`
41+
- Individual Tool Call: Wrapped between `<|tool▁call▁begin|>` and `<|tool▁call▁end|>`
42+
- Function Declaration: `<|tool▁call▁begin|>{function_name}<|tool▁sep|>`
43+
- Arguments: JSON code block between `<|tool▁sep|>` and `<|tool▁call▁end|>`
44+
- Supports multiple tool calls
45+
46+
Reference: https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3.1
47+
"""
48+
)
49+
50+
def __init__(self):
51+
super().__init__()
52+
self.bot_token = "<|tool▁calls▁begin|>" # nosec B105
53+
self.eot_token = "<|tool▁calls▁end|>" # nosec B105
54+
self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
55+
self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)<|tool▁call▁end|>"
56+
self._last_arguments = ""
57+
self.current_tool_id = -1
58+
59+
def has_tool_call(self, text: str) -> bool:
60+
"""Check if the text contains a deepseek format tool call."""
61+
return self.bot_token in text
62+
63+
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
64+
"""One-time parsing: Detects and parses tool calls in the provided text.
65+
66+
:param text: The complete text to parse.
67+
:param tools: List of available tools.
68+
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
69+
"""
70+
idx = text.find(self.bot_token)
71+
normal_text = text[:idx].strip() if idx != -1 else text
72+
if self.bot_token not in text:
73+
return StreamingParseResult(normal_text=normal_text, calls=[])
74+
match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
75+
calls = []
76+
try:
77+
for match_result in match_result_list:
78+
# Get function name
79+
func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
80+
func_name = func_detail.group(1)
81+
func_args = func_detail.group(2)
82+
func_args = json.loads(func_args)
83+
# construct match_result for parse_base_json
84+
match_result = {"name": func_name, "parameters": func_args}
85+
calls.extend(self.parse_base_json(match_result, tools))
86+
return StreamingParseResult(normal_text=normal_text, calls=calls)
87+
except Exception as e:
88+
logger.error(f"Error in detect_and_parse: {e}")
89+
# return the normal text if parsing fails
90+
return StreamingParseResult(normal_text=text)
91+
92+
def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
93+
"""Streaming incremental parsing tool calls for DeepSeekV3 format."""
94+
self._buffer += new_text
95+
current_text = self._buffer
96+
97+
# Check if we have a tool call (either the start token or individual tool call)
98+
has_tool_call = self.bot_token in current_text or "<|tool▁call▁begin|>" in current_text
99+
100+
if not has_tool_call:
101+
if any(
102+
e_token.startswith(new_text)
103+
for e_token in [self.bot_token, "<|tool▁call▁begin|>"]
104+
):
105+
return StreamingParseResult()
106+
self._buffer = ""
107+
for e_token in [self.eot_token, "<|tool▁call▁end|>"]:
108+
if e_token in new_text:
109+
new_text = new_text.replace(e_token, "")
110+
return StreamingParseResult(normal_text=new_text)
111+
112+
if not hasattr(self, "_tool_indices"):
113+
self._tool_indices = self._get_tool_indices(tools)
114+
115+
calls: list[ToolCallItem] = []
116+
try:
117+
partial_match = re.search(
118+
pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*?)(<|tool▁call▁end|>|$)",
119+
string=current_text,
120+
flags=re.DOTALL,
121+
)
122+
if partial_match:
123+
func_name = partial_match.group(1).strip()
124+
func_args_raw = partial_match.group(2).strip()
125+
is_tool_end = partial_match.group(3)
126+
127+
# Initialize state if this is the first tool call
128+
if self.current_tool_id == -1:
129+
self.current_tool_id = 0
130+
self.prev_tool_call_arr = []
131+
self.streamed_args_for_tool = [""]
132+
133+
# Ensure we have enough entries in our tracking arrays
134+
while len(self.prev_tool_call_arr) <= self.current_tool_id:
135+
self.prev_tool_call_arr.append({})
136+
while len(self.streamed_args_for_tool) <= self.current_tool_id:
137+
self.streamed_args_for_tool.append("")
138+
139+
if not self.current_tool_name_sent:
140+
calls.append(
141+
ToolCallItem(
142+
tool_index=self.current_tool_id,
143+
name=func_name,
144+
parameters="",
145+
)
146+
)
147+
self.current_tool_name_sent = True
148+
# Store the tool call info for serving layer completions endpoint
149+
self.prev_tool_call_arr[self.current_tool_id] = {
150+
"name": func_name,
151+
"arguments": {},
152+
}
153+
else:
154+
argument_diff = (
155+
func_args_raw[len(self._last_arguments) :]
156+
if func_args_raw.startswith(self._last_arguments)
157+
else func_args_raw
158+
)
159+
160+
if argument_diff:
161+
calls.append(
162+
ToolCallItem(
163+
tool_index=self.current_tool_id,
164+
name=None,
165+
parameters=argument_diff,
166+
)
167+
)
168+
self._last_arguments += argument_diff
169+
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
170+
171+
if is_complete_json(func_args_raw):
172+
# Update the stored arguments
173+
try:
174+
parsed_args = json.loads(func_args_raw)
175+
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = parsed_args
176+
except json.JSONDecodeError:
177+
pass
178+
179+
# Find the end of the current tool call and remove only that part from buffer
180+
if is_tool_end:
181+
# Remove the completed tool call from buffer, keep any remaining content
182+
self._buffer = current_text[partial_match.end(3) :]
183+
else:
184+
self._buffer = ""
185+
186+
result = StreamingParseResult(normal_text="", calls=calls)
187+
self.current_tool_id += 1
188+
self._last_arguments = ""
189+
self.current_tool_name_sent = False
190+
return result
191+
192+
return StreamingParseResult(normal_text="", calls=calls)
193+
194+
except Exception as e:
195+
logger.error(f"Error in parse_streaming_increment: {e}")
196+
return StreamingParseResult(normal_text=current_text)
197+
198+
def structure_info(self) -> _GetInfoFunc:
199+
return lambda name: StructureInfo(
200+
begin="<|tool▁call▁begin|>" + name + "<|tool▁sep|>",
201+
end="<|tool▁call▁end|>",
202+
trigger="<|tool▁call▁begin|>" + name + "<|tool▁sep|>",
203+
)

0 commit comments

Comments
 (0)