diff --git a/.env.template b/.env.template index 94a67e2..dc76e31 100644 --- a/.env.template +++ b/.env.template @@ -1,8 +1,11 @@ # Required -OPEN_API_KEY=your-key-goes-here +ANTHROPIC_API_KEY=your-anthropic-key-goes-here JWT_SECRET_KEY=your-key-goes-here UDI_MODEL_NAME=HIDIVE/UDI-VIS-Beta-v2-Llama-3.1-8B +# Claude model for orchestration (optional, defaults to claude-sonnet-4-6) +# CLAUDE_MODEL_NAME=claude-sonnet-4-6 + # Optional — uncomment and set as needed # UDI_TOKENIZER_NAME=HIDIVE/UDI-VIS-Beta-v2-Llama-3.1-8B # INSECURE_DEV_MODE=1 diff --git a/README.md b/README.md index e0474b5..8efb9b6 100755 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ To run the entrypoint into the multi-agent system I run a simple python api. `fastapi run ./src/udi_api.py` -This is the endpoint that is called by the YAC frontend. The `udi_api.py` script makes calls to openai and the finetuned model running with vllm. +This is the endpoint that is called by the YAC frontend. The `udi_api.py` script makes calls to Anthropic Claude and the finetuned model running with vllm. ### set environment variables @@ -16,7 +16,8 @@ This is the endpoint that is called by the YAC frontend. The `udi_api.py` script | Item | Command / Value | Description | | ------------------- | ------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| OPEN_API_KEY | `OPEN_API_KEY=your-key-goes-here` | Required. The multi-agent system currently maeks calls to open ai. | +| ANTHROPIC_API_KEY | `ANTHROPIC_API_KEY=your-key-goes-here` | Required. API key for Anthropic Claude, used for orchestration and structured outputs. | +| CLAUDE_MODEL_NAME | `CLAUDE_MODEL_NAME=claude-sonnet-4-6` | Optional. Claude model for orchestration. Defaults to `claude-sonnet-4-6`. | | JWT_SECRET_KEY | `JWT_SECRET_KEY=your-key-goes-here` | Required. Secret key for JWT generation. | | UDI_MODEL_NAME | `UDI_MODEL_NAME=HIDIVE/UDI-VIS-Beta-v2-Llama-3.1-8B` | Required. Path to local or public model name, depending on how the model is served via vllm. Note, this is the model name for the fine-tuned visualization generation model, not a foundation model. | | VLLM_SERVER_URL | `VLLM_SERVER_URL=http://localhost` | Optional. Hostname of the vllm server. Defaults to `http://localhost`. | diff --git a/pyproject.toml b/pyproject.toml index ceecd36..10e460e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.13" dependencies = [ + "anthropic>=0.52.0", "datasets>=4.5.0", "fastapi[standard]>=0.128.0", "huggingface-hub>=0.36.0", diff --git a/src/udi_agent.py b/src/udi_agent.py index 4d3d7fc..d6d500e 100644 --- a/src/udi_agent.py +++ b/src/udi_agent.py @@ -1,5 +1,7 @@ import json from langfuse.openai import OpenAI +from langfuse import observe +import anthropic from jinja2 import Template from transformers import AutoTokenizer @@ -21,114 +23,114 @@ class UDIAgent: def __init__( self, model_name: str, - gpt_model_name: str, + claude_model_name: str, vllm_server_url=None, vllm_server_port=None, tokenizer_name: str = None, ): self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name - self.gpt_model_name = gpt_model_name + self.claude_model_name = claude_model_name if vllm_server_port is not None and vllm_server_url is not None: self.vllm_server_url = vllm_server_url self.vllm_server_port = vllm_server_port self.init_model_connection() - # else: - # self.init_internal_models() def init_model_connection(self): base_url = f"{self.vllm_server_url}:{self.vllm_server_port}/v1" + # vLLM client stays as OpenAI-compatible (unchanged) self.model = OpenAI( api_key="EMPTY", base_url=base_url, ) - self.gpt_model = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + # Anthropic Claude client (replaces OpenAI GPT) + self.claude_model = anthropic.Anthropic( + api_key=os.getenv("ANTHROPIC_API_KEY") + ) # Cache tokenizer + chat template once (avoid reloading per request) tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) self.tokenizer = tokenizer self.chat_template = Template(tokenizer.chat_template) - # def init_internal_models(self): - # self.llm = LLM( - # model=self.model_name, - # gpu_memory_utilization=0.85, - # # dtype="auto", # what serve typically auto-picks on A100 - # # max_model_len=4096, # ↓ KV cache pressure; raise later if needed - # # max_num_seqs=8, # cap concurrent sequences/batch size - # # swap_space=0, # GiB CPU RAM to spill KV blocks - # # enforce_eager=False, # helps fragmentation on some stacks - # # kv_cache_dtype="fp8", # if your vLLM build+model support it, huge win - # ) - - # def basic_test(prompt): - # params = SamplingParams( - # temperature=0.0, - # n=3) - # response = self.llm.generate(prompt, sampling_params=params) - # return response + @staticmethod + def _split_system_messages(messages: list[dict]) -> tuple[str, list[dict]]: + """Extract system messages from the messages list. + + Anthropic requires system messages as a separate parameter, + not in the messages array. + + Returns (system_text, non_system_messages). + """ + system_parts = [] + other_messages = [] + for msg in messages: + if msg.get("role") == "system": + system_parts.append(msg["content"]) + else: + other_messages.append(msg) + return "\n\n".join(system_parts), other_messages def chat_completions(self, messages: list[dict]): response = self.model.chat.completions.create( model=self.model_name, messages=messages, - # max_tokens=40960, max_tokens=120_000, temperature=0.0, - # top_p=1.0, ) return response - # def completions_guided_choice(self, messages: list[dict], tools: list[dict], choices: list[str]): - # tokenizer = AutoTokenizer.from_pretrained(self.model_name) - # chat_template = Template(tokenizer.chat_template) - - # prompt = chat_template.render( - # messages=messages, tools=tools, add_generation_prompt=True - # ) - - # response = self.gpt_model.completions.create( - # model=self.gpt_model_name, - # prompt=prompt, - # max_tokens=100, - # extra_body={ - # "guided_choice": choices, - # } - - # ) - # return response - + @observe() def completions_guided_choice( self, messages: list[dict], tools: list[dict], choices: list[str] ): - schema = { + """Use Anthropic tool use to force a structured choice selection.""" + system_text, user_messages = self._split_system_messages(messages) + + # Define a tool that forces the LLM to return one of the choices + choice_tool = { "name": "ChoiceSelection", - "schema": { + "description": "Select the appropriate choice based on the user's request.", + "input_schema": { "type": "object", - "additionalProperties": False, - "properties": {"choice": {"type": "string", "enum": choices}}, + "properties": { + "choice": { + "type": "string", + "enum": choices, + "description": "The selected choice.", + } + }, "required": ["choice"], }, - "strict": True, } - resp = self.gpt_model.chat.completions.create( - model=self.gpt_model_name, # e.g. "gpt-4.1-mini" - messages=messages, # [{"role":"user","content":"..."}] - response_format={ # <-- key part - "type": "json_schema", - "json_schema": schema, - }, - max_tokens=10, + resp = self.claude_model.messages.create( + model=self.claude_model_name, + system=system_text, + messages=user_messages, + tools=[choice_tool], + tool_choice={"type": "tool", "name": "ChoiceSelection"}, + max_tokens=256, temperature=0.0, ) - # Parse - content = resp.choices[0].message.content - # content is guaranteed to be valid JSON per schema - return json.loads(content)["choice"] - def gpt_completions_guided_json(self, messages: list[dict], json_schema: str, n=1): + # Extract tool use result from response + for block in resp.content: + if block.type == "tool_use": + return block.input["choice"] + + raise ValueError("No tool_use block in Claude response") + + @observe() + def claude_completions_guided_json( + self, messages: list[dict], json_schema: str, n=1 + ): + """Use Anthropic tool use to get structured JSON output. + + Replaces gpt_completions_guided_json. Uses a tool definition to + constrain the output to the given JSON schema. + """ # Normalize schema to dict if isinstance(json_schema, str): try: @@ -138,29 +140,40 @@ def gpt_completions_guided_json(self, messages: list[dict], json_schema: str, n= else: schema_obj = json_schema - # Wrap for Structured Outputs (required shape) - schema_wrapper = { + system_text, user_messages = self._split_system_messages(messages) + + # Define a tool whose input_schema matches the desired output + guided_tool = { "name": "GuidedJSON", - "schema": schema_obj, - "strict": True, # disallow extra fields + "description": "Return the structured JSON output matching the required schema.", + "input_schema": schema_obj, } - resp = self.gpt_model.chat.completions.create( - model=self.gpt_model_name, # e.g. "gpt-4.1-mini" - messages=messages, # [{"role": "user", "content": "..."}] - response_format={ - "type": "json_schema", - "json_schema": schema_wrapper, - }, - n=n, - temperature=0.0, - max_tokens=16_384, - ) + outputs = [] + for _ in range(n): + resp = self.claude_model.messages.create( + model=self.claude_model_name, + system=system_text, + messages=user_messages, + tools=[guided_tool], + tool_choice={"type": "tool", "name": "GuidedJSON"}, + max_tokens=16_384, + temperature=0.0, + ) + + for block in resp.content: + if block.type == "tool_use": + outputs.append(block.input) + break + else: + outputs.append({}) - # Each choice is guaranteed to be valid JSON per schema - outputs = [json.loads(choice.message.content) for choice in resp.choices] return outputs + # Keep old method name as alias for backward compatibility + def gpt_completions_guided_json(self, messages: list[dict], json_schema: str, n=1): + return self.claude_completions_guided_json(messages, json_schema, n) + def completions_guided_json( self, messages: list[dict], tools: list[dict], json_schema: str, n=1 ): @@ -181,86 +194,6 @@ def completions_guided_json( ) return response - # def test_completions_with_more_params(self): - # tokenizer = AutoTokenizer.from_pretrained(self.model_name) - # chat_template = Template(tokenizer.chat_template) - - # messages = [ - # {"role": "system", "content": "You are a helpful assistant."}, - # {"role": "user", "content": "What is the address of the white house?"}, - # ] - # tools = [] - - # print(f"Messages: {messages}") - - # # Debugging, remove all the tool_calls from the messages - # # TODO: try reformatting as strings. - # for message in messages: - # if 'tool_calls' in message: - # del message['tool_calls'] - - # prompt = chat_template.render( - # messages=messages, tools=tools, add_generation_prompt=True) - - # print(f"Prompt: {prompt}") - - # guided_json = ''' - # { - # "$id": "https://example.com/address.schema.json", - # "$schema": "https://json-schema.org/draft/2020-12/schema", - # "description": "An address similar to http://microformats.org/wiki/h-card", - # "type": "object", - # "properties": { - # "postOfficeBox": { - # "type": "string" - # }, - # "extendedAddress": { - # "type": "string" - # }, - # "streetAddress": { - # "type": "string" - # }, - # "thingy77": { - # "type": "string" - # }, - # "region": { - # "type": "string" - # }, - # "postalCode": { - # "type": "string" - # }, - # "countryName": { - # "type": "string" - # } - # }, - # "required": [ "thingy77", "region", "countryName" ], - # "dependentRequired": { - # "postOfficeBox": [ "streetAddress" ], - # "extendedAddress": [ "streetAddress" ] - # } - # } - # ''' - # response = self.model.completions.create( - # model=self.model_name, - # prompt=prompt, - # max_tokens=16_384, - # # n=3, - # # logprobs=3, - # temperature=0.7, - # extra_body={ - # "guided_json": guided_json, - # # "use_beam_search": True - # } - # # sampling_params= { - # # 'use_beam_search':True, - # # } - # # best_of=3, - # # top_p=1.0, - # # use_beam_search=True, - # # guided_json=guided_json - # ) - # return response - def completions(self, messages: list[dict], tools: list[dict]): print(f"Messages: {messages}") @@ -281,6 +214,5 @@ def completions(self, messages: list[dict], tools: list[dict]): prompt=prompt, max_tokens=16_384, temperature=0.0, - # top_p=1.0, ) return response diff --git a/src/udi_api.py b/src/udi_api.py index d84eac6..88916ff 100644 --- a/src/udi_api.py +++ b/src/udi_api.py @@ -38,13 +38,12 @@ allow_headers=["*"], ) +CLAUDE_MODEL_NAME = os.getenv("CLAUDE_MODEL_NAME", "claude-sonnet-4-6") + # init agent agent = UDIAgent( - # model_name="agenticx/UDI-VIS-Beta-v0-Llama-3.1-8B", model_name=MODEL_NAME, - gpt_model_name="gpt-4.1", - # gpt_model_name="gpt-4.1-nano", - # gpt_model_name="gpt-5-nano", + claude_model_name=CLAUDE_MODEL_NAME, vllm_server_url=VLLM_SERVER_URL, vllm_server_port=VLLM_SERVER_PORT, tokenizer_name=TOKENIZER_NAME, diff --git a/src/vis_generate.py b/src/vis_generate.py index b26e4af..cd45f7b 100644 --- a/src/vis_generate.py +++ b/src/vis_generate.py @@ -188,20 +188,38 @@ def load_grammar(grammar_name, base_path="./src"): def _call_llm_with_tools(agent, messages, tools, config): - """Call the LLM with function-calling tools. Returns (tool_name, arguments) or None.""" + """Call the LLM with function-calling tools. Returns (tool_name, arguments) or None. + + Converts OpenAI-style tool definitions to Anthropic format and uses + the Anthropic Claude client. + """ try: - resp = agent.gpt_model.chat.completions.create( - model=agent.gpt_model_name, - messages=messages, - tools=tools, - tool_choice="auto", + # Convert OpenAI tool format to Anthropic format + anthropic_tools = [] + for tool in tools: + func = tool.get("function", tool) + anthropic_tools.append({ + "name": func["name"], + "description": func.get("description", ""), + "input_schema": func.get("parameters", func.get("input_schema", {})), + }) + + # Extract system messages + system_text, user_messages = agent._split_system_messages(messages) + + resp = agent.claude_model.messages.create( + model=agent.claude_model_name, + system=system_text, + messages=user_messages, + tools=anthropic_tools, + tool_choice={"type": "auto"}, temperature=0.0, max_tokens=1024, ) - choice = resp.choices[0] - if choice.message.tool_calls: - tc = choice.message.tool_calls[0] - return tc.function.name, json.loads(tc.function.arguments) + + for block in resp.content: + if block.type == "tool_use": + return block.name, block.input except Exception: pass return None @@ -210,7 +228,7 @@ def _call_llm_with_tools(agent, messages, tools, config): def _call_llm(agent, messages, grammar, config, backend): """Call the LLM and return the raw spec string.""" if backend == "gpt": - results = agent.gpt_completions_guided_json( + results = agent.claude_completions_guided_json( messages=messages, json_schema=grammar["schema_string"], n=config.get("n", 1),