
Pr 2954 ci branch #3006

Open · drbh wants to merge 18 commits into main from pr-2954-ci-branch
Conversation

@drbh (Collaborator) commented Feb 10, 2025

This PR reopens #2954 and adds some small changes to rely on serde where possible. Thank you @Trofleb for the changes!

This PR aligns the tool-calling output to return an array of tool calls and to serialize the tool arguments as a JSON string.

Important

This PR contains breaking changes: the tool-call output is aligned to match OpenAI's format.
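
For reference, the aligned output follows the OpenAI chat completions shape, where tool_calls is an array and arguments is a JSON-encoded string rather than a nested object (an illustrative sketch; the values are examples matching the test output further down):

"tool_calls": [
    {
        "id": "0",
        "type": "function",
        "function": {
            "name": "GetWeather",
            "arguments": "{\"city\": \"zurich\"}"
        }
    }
]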

drbh mentioned this pull request Feb 10, 2025
@Trofleb commented Feb 10, 2025

Hi @drbh,

Thanks for moving forward with my PR!

However, there's something I'm not understanding. Your changes seem to remove my fixes related to streaming. I haven't had time to test the branch, but from reading your changes it seems tool calls will still be impossible to use with streaming via an OpenAI client.

drbh force-pushed the pr-2954-ci-branch branch from 5f88bc4 to 3b09662 on February 17, 2025
@drbh (Collaborator, Author) commented Feb 18, 2025

Hi @Trofleb, thank you again for opening this PR. I've made some small changes, namely to avoid attempting to deserialize the string as JSON at each generation (plus some other tweaks for tests/CI).
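
Conceptually, the change avoids re-parsing the partial output on every token; a rough Python sketch of the idea (the actual change is in the server code, not this literal snippet):

import json

def parse_every_token(tokens):
    # Wasteful: try to parse the accumulated text as JSON at each generation.
    buffer = ""
    for token in tokens:
        buffer += token
        try:
            json.loads(buffer)  # almost always fails until the output is complete
        except json.JSONDecodeError:
            pass
    return json.loads(buffer)

def parse_once(tokens):
    # Cheaper: accumulate tokens and deserialize once the stream is done.
    return json.loads("".join(tokens))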

Additionally, I've added a small test that uses the OpenAI client, in integration-tests/models/test_openai_llama_tools.py, to ensure the client works.

Would you kindly take a look at the PR and let me know if these changes resolve your issue? Thanks!

@Trofleb commented Feb 18, 2025

Hi @drbh, there's just one thing missing: the function name should be sent only in the first stream event.
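
For context, in the OpenAI streaming format the function name appears only in the first tool-call delta; later deltas carry only argument fragments. A sketch of the expected chunk shapes (values are examples):

first delta (carries the function name and id):
    {"index": 0, "id": "0", "function": {"name": "GetWeather", "arguments": ""}}
subsequent deltas (argument fragments only, no name):
    {"index": 0, "function": {"arguments": "{ \"city\": \"zurich\"}"}}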

FYI, my test case:

from urllib.parse import urljoin

from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from openai import BaseModel

# OpenAI-compatible client pointed at a local TGI instance.
LOCAL_FUNCTION_CALL_LLM = ChatOpenAI(
    model="llama",
    max_tokens=512,
    base_url=urljoin("http://localhost:8080", "/v1"),
    api_key="#",
    temperature=0.3,
    frequency_penalty=0,
    disable_streaming=False,
)

weather_prompt = ChatPromptTemplate.from_messages(
    [("human", "What's the weather like in zurich ?")]
)


class GetWeather(BaseModel):
    city: str


# Bind the GetWeather schema so the model must answer via a tool call.
get_weather_pipeline = weather_prompt | LOCAL_FUNCTION_CALL_LLM.with_structured_output(
    GetWeather,
    method="function_calling",
    include_raw=True,
)

res = get_weather_pipeline.invoke({})
print(res)  # OK

for chunk in get_weather_pipeline.stream({}):  # NOK
    print(chunk, end="", flush=True)

The last call ends with the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
...

KeyError: 'GetWeatherGetWeatherGetWeatherGetWeatherGetWeatherGetWeatherGetWeatherGetWeatherGetWeather'

Not sure why langchain does that, but it concatenates the names at the end of the stream, which means it doesn't know which function to call.
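
My guess is that clients aggregate tool-call deltas by concatenating string fields across chunks, so a name repeated in every chunk piles up; a minimal sketch of that aggregation behavior (not langchain's actual code):

def aggregate_tool_call(chunks):
    # Merge streaming deltas by concatenating string fields from each chunk.
    name, arguments = "", ""
    for chunk in chunks:
        name += chunk.get("name", "")
        arguments += chunk.get("arguments", "")
    return {"name": name, "arguments": arguments}

chunks = [
    {"name": "GetWeather", "arguments": '{ "city"'},
    {"name": "GetWeather", "arguments": ': "zurich"}'},
]
print(aggregate_tool_call(chunks)["name"])  # GetWeatherGetWeather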

@drbh (Collaborator, Author) commented Feb 18, 2025

@Trofleb thanks for the information. I've just made a small update to only send the name in the first message of the stream. I've tested with the example provided and receive reasonable output.

TGI started with

text-generation-launcher --model-id meta-llama/Meta-Llama-3.1-8B-Instruct

example output

python tool-langchain-repro.py

*note: the only change made to the example script was adding end="\n" to make the output a bit more readable

output

{'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '0', 'function': {'arguments': '{"city":"zurich"}', 'name': 'GetWeather', 'description': None}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 20, 'prompt_tokens': 213, 'total_tokens': 233, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'llama', 'system_fingerprint': '3.1.1-dev0-native', 'finish_reason': 'stop', 'logprobs': None}, id='run-30c7503d-ed3d-4453-92a5-2a6160ba8694-0', tool_calls=[{'name': 'GetWeather', 'args': {'city': 'zurich'}, 'id': '0', 'type': 'tool_call'}], usage_metadata={'input_tokens': 213, 'output_tokens': 20, 'total_tokens': 233, 'input_token_details': {}, 'output_token_details': {}}), 'parsed': GetWeather(city='zurich'), 'parsing_error': None}
{'raw': AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': '', 'function': {'arguments': '{ "city": "zurich"}', 'name': 'GetWeather'}, 'type': 'function'}]}, response_metadata={}, id='run-280b6234-3dc1-4896-b2bf-53896b708053', tool_calls=[{'name': 'GetWeather', 'args': {'city': 'zurich'}, 'id': '', 'type': 'tool_call'}], tool_call_chunks=[{'name': 'GetWeather', 'args': '{ "city": "zurich"}', 'id': '', 'index': 0, 'type': 'tool_call_chunk'}])}
{'parsed': GetWeather(city='zurich')}
{'parsing_error': None}

drbh force-pushed the pr-2954-ci-branch branch from 22fa0e1 to 3dd0128 on February 19, 2025