Improve tool call message processing #3036

Merged: 8 commits, Feb 21, 2025
2 changes: 1 addition & 1 deletion clients/python/tests/test_client.py
@@ -2,7 +2,7 @@

 from text_generation import Client, AsyncClient
 from text_generation.errors import NotFoundError, ValidationError
-from text_generation.types import FinishReason, InputToken
+from text_generation.types import FinishReason


 def test_generate(llama_7b_url, hf_headers):
68 changes: 52 additions & 16 deletions docs/openapi.json
@@ -1865,25 +1865,57 @@
       }
     },
     "Message": {
-      "type": "object",
-      "required": [
-        "role",
-        "content"
-      ],
-      "properties": {
-        "content": {
-          "$ref": "#/components/schemas/MessageContent"
+      "allOf": [
+        {
+          "$ref": "#/components/schemas/MessageBody"
         },
-        "name": {
-          "type": "string",
-          "example": "\"David\"",
-          "nullable": true
+        {
+          "type": "object",
+          "required": [
+            "role"
+          ],
+          "properties": {
+            "name": {
+              "type": "string",
+              "example": "\"David\"",
+              "nullable": true
+            },
+            "role": {
+              "type": "string",
+              "example": "user"
+            }
+          }
+        }
+      ]
+    },
+    "MessageBody": {
+      "oneOf": [
+        {
+          "type": "object",
+          "required": [
+            "content"
+          ],
+          "properties": {
+            "content": {
+              "$ref": "#/components/schemas/MessageContent"
+            }
+          }
         },
-        "role": {
-          "type": "string",
-          "example": "user"
+        {
+          "type": "object",
+          "required": [
+            "tool_calls"
+          ],
+          "properties": {
+            "tool_calls": {
+              "type": "array",
+              "items": {
+                "$ref": "#/components/schemas/ToolCall"
+              }
+            }
+          }
         }
-      }
+      ]
     },
     "MessageChunk": {
       "oneOf": [
@@ -2179,6 +2211,10 @@
         "role": {
           "type": "string",
           "example": "user"
+        },
+        "tool_call_id": {
+          "type": "string",
+          "nullable": true
         }
       }
     },
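To make the new shape concrete, here is a minimal validation sketch (illustrative only, not part of the PR) against a stripped-down copy of the Message/MessageBody split above. It keeps only the required/oneOf skeleton, not the full property definitions from docs/openapi.json, and assumes the jsonschema package is available.

# Stripped-down skeleton of the new schema: Message = allOf(MessageBody, role),
# where MessageBody = oneOf(content | tool_calls). Illustrative only.
from jsonschema import Draft202012Validator
from jsonschema.exceptions import ValidationError

schema = {
    "components": {
        "schemas": {
            "MessageBody": {
                "oneOf": [
                    {"type": "object", "required": ["content"]},
                    {"type": "object", "required": ["tool_calls"]},
                ]
            }
        }
    },
    "allOf": [
        {"$ref": "#/components/schemas/MessageBody"},
        {"type": "object", "required": ["role"]},
    ],
}

validator = Draft202012Validator(schema)

# A plain text message still validates (the "content" branch).
validator.validate({"role": "user", "content": "What's the weather like in Paris today?"})

# An assistant message carrying only tool calls is now valid too (the "tool_calls" branch).
validator.validate({"role": "assistant", "tool_calls": [{"id": "0", "type": "function"}]})

# A message with neither field fails the oneOf.
try:
    validator.validate({"role": "assistant"})
except ValidationError as err:
    print("rejected:", err.message)

Note that the oneOf also rejects a message setting both content and tool_calls, since such a message would match both branches.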
26 additions (new file): response snapshot for test_flash_llama_tool_reply_response
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
Collaborator commented:

Shouldn't we skip that?

Collaborator (author) replied:

name and tool_calls are actually skipped in the response, but the Python client library adds None when deserializing the response.

The actual response payload is:

{
  "object": "chat.completion",
  "id": "",
  "created": 1740011163,
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "system_fingerprint": "3.1.1-dev0-native",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7\u00b0C (44.1\u00b0F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
      },
      "logprobs": null,
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 103,
    "completion_tokens": 79,
    "total_tokens": 182
  }
}

      },
      "usage": null
    }
  ],
  "created": 1739932427,
  "id": "",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.1.1-dev0-native",
  "usage": {
    "completion_tokens": 79,
    "prompt_tokens": 103,
    "total_tokens": 182
  }
}
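The None-filling behavior described in the thread above comes from optional fields on the client-side response models. A minimal sketch, assuming pydantic v2; the classes below are illustrative stand-ins, not the actual text_generation type definitions.

# The server omits "name" and "tool_calls" from the payload, but optional
# fields on the client model default to None when the keys are absent.
from typing import List, Optional

from pydantic import BaseModel


class ToolCall(BaseModel):
    id: str
    type: str


class Message(BaseModel):
    role: str
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None


# The wire payload carries only "role" and "content"...
msg = Message.model_validate({"role": "assistant", "content": "mostly cloudy, 6.7°C"})

# ...but the parsed object reports the omitted keys as None, which is what the
# snapshot then serializes back out as "name": null and "tool_calls": null.
print(msg.name, msg.tool_calls)  # -> None None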
38 changes: 38 additions & 0 deletions integration-tests/models/test_tools_llama.py
@@ -468,3 +468,41 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
        == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
    )
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_tool_reply_response(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=42,
        messages=[
            {"role": "user", "content": "What's the weather like in Paris today?"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "0",
                        "function": {
                            "arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
                            "name": "get_weather",
                            "description": None,
                        },
                        "type": "function",
                    }
                ],
            },
            {"role": "tool", "tool_call_id": "0", "content": "6.7"},
        ],
        stream=False,
    )

    assert responses.choices[0].message.tool_calls is None
    assert (
        responses.choices[0].message.content
        == "I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\n\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \n\nPlease note that the actual weather may differ from this information, and I recommend checking the forecast on a reliable weather website for the most up-to-date information."
    )

    assert responses == response_snapshot
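The same tool-reply round trip the test exercises can be driven from any OpenAI-compatible client against TGI's chat endpoint. A sketch under assumptions: a TGI server listening at http://localhost:8080; the base_url and api_key placeholder are illustrative, not from the PR.

# Tool-reply flow against the OpenAI-compatible /v1/chat/completions endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_tokens=100,
    messages=[
        {"role": "user", "content": "What's the weather like in Paris today?"},
        {
            # The assistant turn carries only tool_calls and no content:
            # the message shape the schema change above makes legal.
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "0",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"longitude": 2.2945, "latitude": 48.8567}',
                    },
                }
            ],
        },
        # The tool result is linked back to the call via tool_call_id.
        {"role": "tool", "tool_call_id": "0", "content": "6.7"},
    ],
)
print(response.choices[0].message.content)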