11"""
2- Support for Snowflake REST API
2+ Support for Snowflake REST API
33"""
44
5- from typing import TYPE_CHECKING , Any , List , Optional , Tuple
5+ import json
6+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Union
67
78import httpx
89
910from litellm .secret_managers .main import get_secret_str
1011from litellm .types .llms .openai import AllMessageValues
11- from litellm .types .utils import ModelResponse
12+ from litellm .types .utils import ChatCompletionMessageToolCall , Function , ModelResponse
1213
1314from ...openai_like .chat .transformation import OpenAIGPTConfig
1415
2223
2324class SnowflakeConfig (OpenAIGPTConfig ):
2425 """
25- source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
26+ Reference: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api
27+
28+ Snowflake Cortex LLM REST API supports function calling with specific models (e.g., Claude 3.5 Sonnet).
29+ This config handles transformation between OpenAI format and Snowflake's tool_spec format.
2630 """
2731
2832 @classmethod
2933 def get_config (cls ):
3034 return super ().get_config ()
3135
32- def get_supported_openai_params (self , model : str ) -> List :
33- return ["temperature" , "max_tokens" , "top_p" , "response_format" ]
36+ def get_supported_openai_params (self , model : str ) -> List [str ]:
37+ return [
38+ "temperature" ,
39+ "max_tokens" ,
40+ "top_p" ,
41+ "response_format" ,
42+ "tools" ,
43+ "tool_choice" ,
44+ ]
3445
3546 def map_openai_params (
3647 self ,
@@ -56,6 +67,57 @@ def map_openai_params(
5667 optional_params [param ] = value
5768 return optional_params
5869
70+ def _transform_tool_calls_from_snowflake_to_openai (
71+ self , content_list : List [Dict [str , Any ]]
72+ ) -> Tuple [str , Optional [List [ChatCompletionMessageToolCall ]]]:
73+ """
74+ Transform Snowflake tool calls to OpenAI format.
75+
76+ Args:
77+ content_list: Snowflake's content_list array containing text and tool_use items
78+
79+ Returns:
80+ Tuple of (text_content, tool_calls)
81+
82+ Snowflake format in content_list:
83+ {
84+ "type": "tool_use",
85+ "tool_use": {
86+ "tool_use_id": "tooluse_...",
87+ "name": "get_weather",
88+ "input": {"location": "Paris"}
89+ }
90+ }
91+
92+ OpenAI format (returned tool_calls):
93+ ChatCompletionMessageToolCall(
94+ id="tooluse_...",
95+ type="function",
96+ function=Function(name="get_weather", arguments='{"location": "Paris"}')
97+ )
98+ """
99+ text_content = ""
100+ tool_calls : List [ChatCompletionMessageToolCall ] = []
101+
102+ for idx , content_item in enumerate (content_list ):
103+ if content_item .get ("type" ) == "text" :
104+ text_content += content_item .get ("text" , "" )
105+
106+ ## TOOL CALLING
107+ elif content_item .get ("type" ) == "tool_use" :
108+ tool_use_data = content_item .get ("tool_use" , {})
109+ tool_call = ChatCompletionMessageToolCall (
110+ id = tool_use_data .get ("tool_use_id" , "" ),
111+ type = "function" ,
112+ function = Function (
113+ name = tool_use_data .get ("name" , "" ),
114+ arguments = json .dumps (tool_use_data .get ("input" , {})),
115+ ),
116+ )
117+ tool_calls .append (tool_call )
118+
119+ return text_content , tool_calls if tool_calls else None
120+
59121 def transform_response (
60122 self ,
61123 model : str ,
@@ -71,13 +133,34 @@ def transform_response(
71133 json_mode : Optional [bool ] = None ,
72134 ) -> ModelResponse :
73135 response_json = raw_response .json ()
136+
74137 logging_obj .post_call (
75138 input = messages ,
76139 api_key = "" ,
77140 original_response = response_json ,
78141 additional_args = {"complete_input_dict" : request_data },
79142 )
80143
144+ ## RESPONSE TRANSFORMATION
145+ # Snowflake returns content_list (not content) with tool_use objects
146+ # We need to transform this to OpenAI's format with content + tool_calls
147+ if "choices" in response_json and len (response_json ["choices" ]) > 0 :
148+ choice = response_json ["choices" ][0 ]
149+ if "message" in choice and "content_list" in choice ["message" ]:
150+ content_list = choice ["message" ]["content_list" ]
151+ (
152+ text_content ,
153+ tool_calls ,
154+ ) = self ._transform_tool_calls_from_snowflake_to_openai (content_list )
155+
156+ # Update the choice message with OpenAI format
157+ choice ["message" ]["content" ] = text_content
158+ if tool_calls :
159+ choice ["message" ]["tool_calls" ] = tool_calls
160+
161+ # Remove Snowflake-specific content_list
162+ del choice ["message" ]["content_list" ]
163+
81164 returned_response = ModelResponse (** response_json )
82165
83166 returned_response .model = "snowflake/" + (returned_response .model or "" )
@@ -150,6 +233,95 @@ def get_complete_url(
150233
151234 return api_base
152235
236+ def _transform_tools (self , tools : List [Dict [str , Any ]]) -> List [Dict [str , Any ]]:
237+ """
238+ Transform OpenAI tool format to Snowflake tool format.
239+
240+ Args:
241+ tools: List of tools in OpenAI format
242+
243+ Returns:
244+ List of tools in Snowflake format
245+
246+ OpenAI format:
247+ {
248+ "type": "function",
249+ "function": {
250+ "name": "get_weather",
251+ "description": "...",
252+ "parameters": {...}
253+ }
254+ }
255+
256+ Snowflake format:
257+ {
258+ "tool_spec": {
259+ "type": "generic",
260+ "name": "get_weather",
261+ "description": "...",
262+ "input_schema": {...}
263+ }
264+ }
265+ """
266+ snowflake_tools : List [Dict [str , Any ]] = []
267+ for tool in tools :
268+ if tool .get ("type" ) == "function" :
269+ function = tool .get ("function" , {})
270+ snowflake_tool : Dict [str , Any ] = {
271+ "tool_spec" : {
272+ "type" : "generic" ,
273+ "name" : function .get ("name" ),
274+ "input_schema" : function .get (
275+ "parameters" ,
276+ {"type" : "object" , "properties" : {}},
277+ ),
278+ }
279+ }
280+ # Add description if present
281+ if "description" in function :
282+ snowflake_tool ["tool_spec" ]["description" ] = function [
283+ "description"
284+ ]
285+
286+ snowflake_tools .append (snowflake_tool )
287+
288+ return snowflake_tools
289+
290+ def _transform_tool_choice (
291+ self , tool_choice : Union [str , Dict [str , Any ]]
292+ ) -> Union [str , Dict [str , Any ]]:
293+ """
294+ Transform OpenAI tool_choice format to Snowflake format.
295+
296+ Args:
297+ tool_choice: Tool choice in OpenAI format (str or dict)
298+
299+ Returns:
300+ Tool choice in Snowflake format
301+
302+ OpenAI format:
303+ {"type": "function", "function": {"name": "get_weather"}}
304+
305+ Snowflake format:
306+ {"type": "tool", "name": ["get_weather"]}
307+
308+ Note: String values ("auto", "required", "none") pass through unchanged.
309+ """
310+ if isinstance (tool_choice , str ):
311+ # "auto", "required", "none" pass through as-is
312+ return tool_choice
313+
314+ if isinstance (tool_choice , dict ):
315+ if tool_choice .get ("type" ) == "function" :
316+ function_name = tool_choice .get ("function" , {}).get ("name" )
317+ if function_name :
318+ return {
319+ "type" : "tool" ,
320+ "name" : [function_name ], # Snowflake expects array
321+ }
322+
323+ return tool_choice
324+
153325 def transform_request (
154326 self ,
155327 model : str ,
@@ -160,6 +332,18 @@ def transform_request(
160332 ) -> dict :
161333 stream : bool = optional_params .pop ("stream" , None ) or False
162334 extra_body = optional_params .pop ("extra_body" , {})
335+
336+ ## TOOL CALLING
337+ # Transform tools from OpenAI format to Snowflake's tool_spec format
338+ tools = optional_params .pop ("tools" , None )
339+ if tools :
340+ optional_params ["tools" ] = self ._transform_tools (tools )
341+
342+ # Transform tool_choice from OpenAI format to Snowflake's tool name array format
343+ tool_choice = optional_params .pop ("tool_choice" , None )
344+ if tool_choice :
345+ optional_params ["tool_choice" ] = self ._transform_tool_choice (tool_choice )
346+
163347 return {
164348 "model" : model ,
165349 "messages" : messages ,
0 commit comments