196 changes: 190 additions & 6 deletions litellm/llms/snowflake/chat/transformation.py
@@ -1,14 +1,15 @@
"""
Support for Snowflake REST API
"""

from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import httpx

from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from litellm.types.utils import ChatCompletionMessageToolCall, Function, ModelResponse

from ...openai_like.chat.transformation import OpenAIGPTConfig

@@ -22,15 +23,25 @@

class SnowflakeConfig(OpenAIGPTConfig):
"""
source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
Reference: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api

The Snowflake Cortex LLM REST API supports function calling on specific models (e.g., Claude 3.5 Sonnet).
This config handles the transformation between the OpenAI tool format and Snowflake's tool_spec format.
"""

@classmethod
def get_config(cls):
return super().get_config()

def get_supported_openai_params(self, model: str) -> List:
return ["temperature", "max_tokens", "top_p", "response_format"]
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"temperature",
"max_tokens",
"top_p",
"response_format",
"tools",
"tool_choice",
]

def map_openai_params(
self,
@@ -56,6 +67,57 @@ def map_openai_params(
optional_params[param] = value
return optional_params

def _transform_tool_calls_from_snowflake_to_openai(
self, content_list: List[Dict[str, Any]]
) -> Tuple[str, Optional[List[ChatCompletionMessageToolCall]]]:
"""
Transform Snowflake tool calls to OpenAI format.

Args:
content_list: Snowflake's content_list array containing text and tool_use items

Returns:
Tuple of (text_content, tool_calls)

Snowflake format in content_list:
{
"type": "tool_use",
"tool_use": {
"tool_use_id": "tooluse_...",
"name": "get_weather",
"input": {"location": "Paris"}
}
}

OpenAI format (returned tool_calls):
ChatCompletionMessageToolCall(
id="tooluse_...",
type="function",
function=Function(name="get_weather", arguments='{"location": "Paris"}')
)
"""
text_content = ""
tool_calls: List[ChatCompletionMessageToolCall] = []

for idx, content_item in enumerate(content_list):
if content_item.get("type") == "text":
text_content += content_item.get("text", "")

## TOOL CALLING
elif content_item.get("type") == "tool_use":
tool_use_data = content_item.get("tool_use", {})
tool_call = ChatCompletionMessageToolCall(
id=tool_use_data.get("tool_use_id", ""),
type="function",
function=Function(
name=tool_use_data.get("name", ""),
arguments=json.dumps(tool_use_data.get("input", {})),
),
)
tool_calls.append(tool_call)

return text_content, tool_calls if tool_calls else None
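
A minimal usage sketch of the helper above (illustration only, not part of the diff; it assumes SnowflakeConfig() constructs with defaults and that Function exposes .name / .arguments as defined in litellm.types.utils):

from litellm.llms.snowflake.chat.transformation import SnowflakeConfig

# Hypothetical content_list, shaped like Snowflake's response format
sample_content_list = [
    {"type": "text", "text": "Let me check the weather."},
    {
        "type": "tool_use",
        "tool_use": {
            "tool_use_id": "tooluse_example",  # made-up id
            "name": "get_weather",
            "input": {"location": "Paris"},
        },
    },
]

text, tool_calls = SnowflakeConfig()._transform_tool_calls_from_snowflake_to_openai(
    sample_content_list
)
assert text == "Let me check the weather."
assert tool_calls[0].function.name == "get_weather"
assert tool_calls[0].function.arguments == '{"location": "Paris"}'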

def transform_response(
self,
model: str,
@@ -71,13 +133,34 @@ def transform_response(
json_mode: Optional[bool] = None,
) -> ModelResponse:
response_json = raw_response.json()

logging_obj.post_call(
input=messages,
api_key="",
original_response=response_json,
additional_args={"complete_input_dict": request_data},
)

## RESPONSE TRANSFORMATION
# Snowflake returns content_list (not content) with tool_use objects
# We need to transform this to OpenAI's format with content + tool_calls
if "choices" in response_json and len(response_json["choices"]) > 0:
choice = response_json["choices"][0]
if "message" in choice and "content_list" in choice["message"]:
content_list = choice["message"]["content_list"]
(
text_content,
tool_calls,
) = self._transform_tool_calls_from_snowflake_to_openai(content_list)

# Update the choice message with OpenAI format
choice["message"]["content"] = text_content
if tool_calls:
choice["message"]["tool_calls"] = tool_calls

# Remove Snowflake-specific content_list
del choice["message"]["content_list"]

returned_response = ModelResponse(**response_json)

returned_response.model = "snowflake/" + (returned_response.model or "")
@@ -150,6 +233,95 @@ def get_complete_url(

return api_base

def _transform_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Transform OpenAI tool format to Snowflake tool format.

Args:
tools: List of tools in OpenAI format

Returns:
List of tools in Snowflake format

OpenAI format:
{
"type": "function",
"function": {
"name": "get_weather",
"description": "...",
"parameters": {...}
}
}

Snowflake format:
{
"tool_spec": {
"type": "generic",
"name": "get_weather",
"description": "...",
"input_schema": {...}
}
}
"""
snowflake_tools: List[Dict[str, Any]] = []
for tool in tools:
if tool.get("type") == "function":
function = tool.get("function", {})
snowflake_tool: Dict[str, Any] = {
"tool_spec": {
"type": "generic",
"name": function.get("name"),
"input_schema": function.get(
"parameters",
{"type": "object", "properties": {}},
),
}
}
# Add description if present
if "description" in function:
snowflake_tool["tool_spec"]["description"] = function[
"description"
]

snowflake_tools.append(snowflake_tool)

return snowflake_tools
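
A quick sketch of the conversion this method performs (illustrative input; same hedge as above about constructing SnowflakeConfig directly):

from litellm.llms.snowflake.chat.transformation import SnowflakeConfig

openai_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

snowflake_tools = SnowflakeConfig()._transform_tools(openai_tools)
assert snowflake_tools[0]["tool_spec"]["type"] == "generic"
assert snowflake_tools[0]["tool_spec"]["name"] == "get_weather"
assert snowflake_tools[0]["tool_spec"]["input_schema"]["required"] == ["location"]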

def _transform_tool_choice(
self, tool_choice: Union[str, Dict[str, Any]]
) -> Union[str, Dict[str, Any]]:
"""
Transform OpenAI tool_choice format to Snowflake format.

Args:
tool_choice: Tool choice in OpenAI format (str or dict)

Returns:
Tool choice in Snowflake format

OpenAI format:
{"type": "function", "function": {"name": "get_weather"}}

Snowflake format:
{"type": "tool", "name": ["get_weather"]}

Note: String values ("auto", "required", "none") pass through unchanged.
"""
if isinstance(tool_choice, str):
# "auto", "required", "none" pass through as-is
return tool_choice

if isinstance(tool_choice, dict):
if tool_choice.get("type") == "function":
function_name = tool_choice.get("function", {}).get("name")
if function_name:
return {
"type": "tool",
"name": [function_name], # Snowflake expects array
}

return tool_choice
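
And the corresponding sketch for tool_choice (illustrative only, not part of the diff):

from litellm.llms.snowflake.chat.transformation import SnowflakeConfig

cfg = SnowflakeConfig()
assert cfg._transform_tool_choice("auto") == "auto"  # strings pass through unchanged
assert cfg._transform_tool_choice(
    {"type": "function", "function": {"name": "get_weather"}}
) == {"type": "tool", "name": ["get_weather"]}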

def transform_request(
self,
model: str,
@@ -160,6 +332,18 @@
) -> dict:
stream: bool = optional_params.pop("stream", None) or False
extra_body = optional_params.pop("extra_body", {})

## TOOL CALLING
# Transform tools from OpenAI format to Snowflake's tool_spec format
tools = optional_params.pop("tools", None)
if tools:
optional_params["tools"] = self._transform_tools(tools)

# Transform tool_choice from OpenAI format to Snowflake's tool name array format
tool_choice = optional_params.pop("tool_choice", None)
if tool_choice:
optional_params["tool_choice"] = self._transform_tool_choice(tool_choice)

return {
"model": model,
"messages": messages,
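For orientation, a rough sketch of the body transform_request would send to Snowflake for a tools request. This is an assumption-laden illustration, not output captured from the PR: the tail of transform_request is collapsed above, so the merging of the remaining optional params and extra_body into the returned dict is assumed.

expected_body = {
    "model": "claude-3-5-sonnet",
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "stream": False,
    # tools rewritten into Snowflake's tool_spec format by _transform_tools
    "tools": [
        {
            "tool_spec": {
                "type": "generic",
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "input_schema": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            }
        }
    ],
    # tool_choice rewritten into Snowflake's name-array format by _transform_tool_choice
    "tool_choice": {"type": "tool", "name": ["get_weather"]},
}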
69 changes: 68 additions & 1 deletion tests/llm_translation/test_snowflake.py
@@ -6,7 +6,7 @@
load_dotenv()
import pytest

from litellm import completion, acompletion
from litellm import completion, acompletion, responses
from litellm.exceptions import APIConnectionError

@pytest.mark.parametrize("sync_mode", [True, False])
@@ -87,3 +87,70 @@ async def test_chat_completion_snowflake_stream(sync_mode):
raise # Re-raise if it's a different APIConnectionError
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@pytest.mark.skip(reason="Requires Snowflake credentials - run manually when needed")
def test_snowflake_tool_calling_responses_api():
"""
Test Snowflake tool calling with Responses API.
Requires SNOWFLAKE_JWT and SNOWFLAKE_ACCOUNT_ID environment variables.
"""
import litellm

# Skip if credentials not available
if not os.getenv("SNOWFLAKE_JWT") or not os.getenv("SNOWFLAKE_ACCOUNT_ID"):
pytest.skip("Snowflake credentials not available")

litellm.drop_params = False # We now support tools!

tools = [
{
"type": "function",
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["location"],
},
}
]

try:
# Test with tool_choice to force tool use
response = responses(
model="snowflake/claude-3-5-sonnet",
input="What's the weather in Paris?",
tools=tools,
tool_choice={"type": "function", "function": {"name": "get_weather"}},
max_output_tokens=200,
)

assert response is not None
assert hasattr(response, "output")
assert len(response.output) > 0

# Verify tool call was made
tool_call_found = False
for item in response.output:
if hasattr(item, "type") and item.type == "function_call":
tool_call_found = True
assert item.name == "get_weather"
assert hasattr(item, "arguments")
print(f"✅ Tool call detected: {item.name}({item.arguments})")
break

assert tool_call_found, "Expected tool call but none was found"

except APIConnectionError as e:
if "JWT token is invalid" in str(e):
pytest.skip("Invalid Snowflake JWT token")
elif "Application failed to respond" in str(e) or "502" in str(e):
pytest.skip(f"Snowflake API unavailable: {e}")
else:
raise
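
For completeness, a hedged chat-completions sketch of the same flow (the shipped test above uses the Responses API). This is not taken from the PR; it assumes credentials are configured as in the test and that snowflake/claude-3-5-sonnet is enabled for the account:

import litellm

response = litellm.completion(
    model="snowflake/claude-3-5-sonnet",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)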