Skip to content

Commit c417ab4

Browse files
fix(tool-calling): parser improvements and non-streaming fix
- Add alternate format detection for <tool_name {json}> patterns - Handle cases where small LLMs use tool name as tag instead of <tool_call> - Use non-streaming generate() to avoid SIGABRT crash - Support direct arguments format in alternate format These fixes improve compatibility with smaller models like SmolLM2.
1 parent 65474bd commit c417ab4

File tree

2 files changed

+73
-13
lines changed

2 files changed

+73
-13
lines changed

sdk/runanywhere-react-native/packages/core/cpp/bridges/ToolCallingBridge.cpp

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <nlohmann/json.hpp>
1212
#include <sstream>
1313
#include <cstring>
14+
#include <cctype>
1415

1516
using json = nlohmann::json;
1617

@@ -131,9 +132,52 @@ std::string ToolCallingBridge::formatToolsPrompt(const std::string& toolsJson) {
131132
}
132133

133134
std::string ToolCallingBridge::parseToolCall(const std::string& llmOutput) {
134-
// Find tool call tags
135+
// Find tool call tags (primary format)
135136
size_t tagStart = llmOutput.find(TOOL_CALL_START_TAG);
136137

138+
// If no <tool_call> tag, check for alternate format: <tool_name {json}>
139+
// Some smaller models use the tool name as the tag instead of <tool_call>
140+
bool usingAlternateFormat = false;
141+
std::string alternateToolName;
142+
143+
if (tagStart == std::string::npos) {
144+
// Look for pattern: <word followed by space/{ and JSON
145+
size_t ltPos = 0;
146+
while ((ltPos = llmOutput.find('<', ltPos)) != std::string::npos) {
147+
size_t nameStart = ltPos + 1;
148+
size_t nameEnd = nameStart;
149+
while (nameEnd < llmOutput.size() &&
150+
(std::isalnum(static_cast<unsigned char>(llmOutput[nameEnd])) ||
151+
llmOutput[nameEnd] == '_' || llmOutput[nameEnd] == '-')) {
152+
nameEnd++;
153+
}
154+
155+
if (nameEnd > nameStart) {
156+
std::string tagName = llmOutput.substr(nameStart, nameEnd - nameStart);
157+
158+
// Skip common HTML-like tags
159+
if (tagName != "p" && tagName != "br" && tagName != "div" &&
160+
tagName != "span" && tagName != "a" && tagName.length() > 2) {
161+
162+
size_t jsonStartCheck = nameEnd;
163+
while (jsonStartCheck < llmOutput.size() &&
164+
(llmOutput[jsonStartCheck] == ' ' || llmOutput[jsonStartCheck] == '\t' ||
165+
llmOutput[jsonStartCheck] == '\n' || llmOutput[jsonStartCheck] == '>')) {
166+
jsonStartCheck++;
167+
}
168+
169+
if (jsonStartCheck < llmOutput.size() && llmOutput[jsonStartCheck] == '{') {
170+
tagStart = ltPos;
171+
usingAlternateFormat = true;
172+
alternateToolName = tagName;
173+
break;
174+
}
175+
}
176+
}
177+
ltPos++;
178+
}
179+
}
180+
137181
if (tagStart == std::string::npos) {
138182
// No tool call found - return clean text with hasToolCall = false
139183
json result;
@@ -142,9 +186,21 @@ std::string ToolCallingBridge::parseToolCall(const std::string& llmOutput) {
142186
return result.dump();
143187
}
144188

145-
// Find end tag
146-
size_t jsonStart = tagStart + strlen(TOOL_CALL_START_TAG);
147-
size_t tagEnd = llmOutput.find(TOOL_CALL_END_TAG, jsonStart);
189+
// Find JSON start position
190+
size_t jsonStart;
191+
if (usingAlternateFormat) {
192+
jsonStart = tagStart + 1 + alternateToolName.length();
193+
while (jsonStart < llmOutput.size() &&
194+
(llmOutput[jsonStart] == ' ' || llmOutput[jsonStart] == '\t' ||
195+
llmOutput[jsonStart] == '\n' || llmOutput[jsonStart] == '>')) {
196+
jsonStart++;
197+
}
198+
} else {
199+
jsonStart = tagStart + strlen(TOOL_CALL_START_TAG);
200+
}
201+
202+
// Find end tag (only for standard format)
203+
size_t tagEnd = usingAlternateFormat ? std::string::npos : llmOutput.find(TOOL_CALL_END_TAG, jsonStart);
148204
bool hasClosingTag = (tagEnd != std::string::npos);
149205

150206
if (!hasClosingTag) {
@@ -204,12 +260,15 @@ std::string ToolCallingBridge::parseToolCall(const std::string& llmOutput) {
204260
return result.dump();
205261
}
206262

207-
// Extract tool name (try "tool" first, then "name")
263+
// Extract tool name (try "tool" first, then "name", then use alternate format tag name)
208264
std::string toolName;
209265
if (toolJson.contains("tool") && toolJson["tool"].is_string()) {
210266
toolName = toolJson["tool"].get<std::string>();
211267
} else if (toolJson.contains("name") && toolJson["name"].is_string()) {
212268
toolName = toolJson["name"].get<std::string>();
269+
} else if (usingAlternateFormat && !alternateToolName.empty()) {
270+
// Use the tag name as tool name (e.g., <search_restaurants {args}> -> "search_restaurants")
271+
toolName = alternateToolName;
213272
} else {
214273
// Could not find tool name
215274
json result;
@@ -219,11 +278,16 @@ std::string ToolCallingBridge::parseToolCall(const std::string& llmOutput) {
219278
}
220279

221280
// Extract arguments (try "arguments" first, then "params")
281+
// For alternate format without explicit arguments, the JSON itself might be the arguments
222282
json arguments = json::object();
223283
if (toolJson.contains("arguments") && toolJson["arguments"].is_object()) {
224284
arguments = toolJson["arguments"];
225285
} else if (toolJson.contains("params") && toolJson["params"].is_object()) {
226286
arguments = toolJson["params"];
287+
} else if (usingAlternateFormat && !toolJson.contains("tool") && !toolJson.contains("name")) {
288+
// In alternate format like <search_restaurants {"query": "food"}>,
289+
// the entire JSON is the arguments
290+
arguments = toolJson;
227291
}
228292

229293
// Build the clean text (everything except the tool call tags)

sdk/runanywhere-react-native/packages/core/src/Public/Extensions/RunAnywhere+ToolCalling.ts

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
*/
1111

1212
import { SDKLogger } from '../../Foundation/Logging/Logger/SDKLogger';
13-
import { generateStream } from './RunAnywhere+TextGeneration';
13+
import { generate } from './RunAnywhere+TextGeneration';
1414
import {
1515
requireNativeModule,
1616
isNativeModuleAvailable,
@@ -254,16 +254,12 @@ export async function generateWithTools(
254254
iterations++;
255255
logger.debug(`[ToolCalling] === Iteration ${iterations} ===`);
256256

257-
// Generate response
258-
let responseText = '';
259-
const streamResult = await generateStream(fullPrompt, {
257+
// Generate response (using non-streaming API to avoid streaming crash)
258+
const generateResult = await generate(fullPrompt, {
260259
maxTokens: options?.maxTokens,
261260
temperature: options?.temperature,
262261
});
263-
264-
for await (const token of streamResult.stream) {
265-
responseText += token;
266-
}
262+
const responseText = generateResult.text;
267263

268264
logger.debug(`[ToolCalling] Raw response (${responseText.length} chars): ${responseText.substring(0, 300)}`);
269265

0 commit comments

Comments
 (0)