diff --git a/core/changelog.md b/core/changelog.md index 5301cc548..4b63a0795 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -3,4 +3,5 @@ - fix: add support for AdditionalProperties structures (both boolean and object types) - fix: improve thought signature handling in gemini for function calls - fix: enhance citations structure to support multiple citation types -- fix: anthropic streaming events through integration \ No newline at end of file +- fix: anthropic streaming events through integration +- feat: added support for code execution tool for openai, anthropic and gemini \ No newline at end of file diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index e46edfcbf..8cbdb35db 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -2095,6 +2095,17 @@ func (response *AnthropicMessageResponse) ToBifrostResponsesResponse() *schemas. }, } outputMessages := ConvertAnthropicMessagesToBifrostMessages([]AnthropicMessage{tempMsg}, nil, true, false) + // Add container ID and container expire at using the Anthropic Response + if response.Container != nil { + // Find the corresponding code interpreter call by type=code_interpreter_call + for i, msg := range outputMessages { + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeCodeInterpreterCall && msg.ResponsesCodeInterpreterToolCall != nil { + outputMessages[i].ResponsesToolMessage.ResponsesCodeInterpreterToolCall.ContainerID = response.Container.ID + outputMessages[i].ResponsesToolMessage.ResponsesCodeInterpreterToolCall.ExpiresAt = &response.Container.ExpiresAt + break + } + } + } if len(outputMessages) > 0 { bifrostResp.Output = outputMessages } @@ -2164,6 +2175,27 @@ func ToAnthropicResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) } } + // Extract container information from code interpreter calls + if bifrostResp.Output != nil { + for _, msg := range bifrostResp.Output { + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeCodeInterpreterCall && + msg.ResponsesToolMessage != nil && + msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall != nil { + codeInterpreter := msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall + if codeInterpreter.ContainerID != "" { + container := &AnthropicContainer{ + ID: codeInterpreter.ContainerID, + } + if codeInterpreter.ExpiresAt != nil && *codeInterpreter.ExpiresAt != "" { + container.ExpiresAt = *codeInterpreter.ExpiresAt + } + anthropicResp.Container = container + break // Only need the first code interpreter container + } + } + } + } + anthropicResp.Model = bifrostResp.Model return anthropicResp @@ -2592,9 +2624,43 @@ func ConvertBifrostMessagesToAnthropicMessages(bifrostMessages []schemas.Respons } } + case schemas.ResponsesMessageTypeCodeInterpreterCall: + // Flush any pending tool results before processing code interpreter calls + flushPendingToolResults() + + // Code interpreter calls need special handling: create server_tool_use + bash_code_execution_tool_result blocks + codeInterpreterBlocks := convertBifrostCodeInterpreterCallToAnthropicBlocks(&msg) + if len(codeInterpreterBlocks) > 0 { + // For code interpreter, we create both server_tool_use and bash_code_execution_tool_result + // These should appear in an assistant message + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } + } + + // Prepend any pending reasoning blocks to ensure they come BEFORE tool blocks + if len(pendingReasoningContentBlocks) > 0 { + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + pendingToolCalls = append(copied, pendingToolCalls...) + pendingReasoningContentBlocks = nil + } + + // Add the code interpreter blocks (server_tool_use + bash_code_execution_tool_result) + pendingToolCalls = append(pendingToolCalls, codeInterpreterBlocks...) + + // Track the tool call ID for the server_tool_use block (first block) + if len(codeInterpreterBlocks) > 0 && codeInterpreterBlocks[0].ID != nil { + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) + } + currentToolCallIDs[*codeInterpreterBlocks[0].ID] = true + } + } + // Handle other tool call types that are not natively supported by Anthropic case schemas.ResponsesMessageTypeFileSearchCall, - schemas.ResponsesMessageTypeCodeInterpreterCall, schemas.ResponsesMessageTypeLocalShellCall, schemas.ResponsesMessageTypeCustomToolCall, schemas.ResponsesMessageTypeImageGenerationCall: @@ -3205,7 +3271,6 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC bifrostMessages = append(bifrostMessages, bifrostMsg) } } - case AnthropicContentBlockTypeServerToolUse: // Check if it's a web_search tool if block.Name != nil && *block.Name == string(AnthropicToolNameWebSearch) { @@ -3238,6 +3303,28 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC } } + // Check if its code execution tool use + if block.Name != nil && *block.Name == string(AnthropicToolNameBashCodeExecution) { + bifrostMsg := schemas.ResponsesMessage{ + ID: block.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeCodeInterpreterCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + ResponsesCodeInterpreterToolCall: &schemas.ResponsesCodeInterpreterToolCall{ + Outputs: []schemas.ResponsesCodeInterpreterOutput{}, + }, + }, + } + if block.Input != nil { + if inputMap, ok := block.Input.(map[string]interface{}); ok { + if code, ok := inputMap["command"].(string); ok { + bifrostMsg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Code = &code + } + } + } + // Set container ID and container expire at using the Anthropic Response + bifrostMessages = append(bifrostMessages, bifrostMsg) + } case AnthropicContentBlockTypeWebSearchResult: // Find the corresponding web_search_call by tool_use_id if block.ToolUseID != nil { @@ -3335,6 +3422,34 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC } bifrostMessages = append(bifrostMessages, bifrostMsg) } + case AnthropicContentBlockTypeBashCodeExecutionToolResult: + // find the corresponding code interpreter call by tool_use_id + if block.ToolUseID != nil { + for i := len(bifrostMessages) - 1; i >= 0; i-- { + msg := &bifrostMessages[i] + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeCodeInterpreterCall && + msg.ResponsesToolMessage != nil && + msg.ResponsesToolMessage.CallID != nil && + *msg.ResponsesToolMessage.CallID == *block.ToolUseID { + codeExecutionBlock := block.Content.ContentBlock + if codeExecutionBlock == nil || codeExecutionBlock.Type != AnthropicContentBlockTypeBashCodeExecutionResult { + continue + } + // Add this result to the code interpreter call outputs + var log schemas.ResponsesCodeInterpreterOutputLogs + log.Type = "logs" + log.ReturnCode = codeExecutionBlock.ReturnCode + if codeExecutionBlock.StdOut != nil { + log.Logs = *codeExecutionBlock.StdOut + } else if codeExecutionBlock.StdErr != nil { + log.Logs = *codeExecutionBlock.StdErr + } + msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs = append(msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs, schemas.ResponsesCodeInterpreterOutput{ + ResponsesCodeInterpreterOutputLogs: &log, + }) + } + } + } default: // Handle other block types if needed } @@ -3671,25 +3786,6 @@ func convertBifrostMCPCallToAnthropicToolUse(msg *schemas.ResponsesMessage) *Ant return nil } -// convertBifrostMCPCallOutputToAnthropicMessage converts a Bifrost MCP call output to Anthropic message -func convertBifrostMCPCallOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { - toolResultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolResult, - ID: msg.ResponsesToolMessage.CallID, - } - - if msg.ResponsesToolMessage.Output != nil { - toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } - - return &AnthropicMessage{ - Role: AnthropicMessageRoleUser, - Content: AnthropicContent{ - ContentBlocks: []AnthropicContentBlock{toolResultBlock}, - }, - } -} - // convertBifrostMCPApprovalToAnthropicToolUse converts a Bifrost MCP approval request to Anthropic tool use func convertBifrostMCPApprovalToAnthropicToolUse(msg *schemas.ResponsesMessage) *AnthropicContentBlock { if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { @@ -3777,6 +3873,101 @@ func convertBifrostWebSearchCallToAnthropicBlocks(msg *schemas.ResponsesMessage) return blocks } +// convertBifrostCodeInterpreterCallToAnthropicBlocks converts a Bifrost code_interpreter_call to Anthropic server_tool_use and bash_code_execution_tool_result blocks +func convertBifrostCodeInterpreterCallToAnthropicBlocks(msg *schemas.ResponsesMessage) []AnthropicContentBlock { + if msg.ResponsesToolMessage == nil || msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall == nil { + return nil + } + + var blocks []AnthropicContentBlock + codeInterpreter := msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall + + // Ensure we have a valid ID for the tool use block (critical for linkage) + if msg.ID == nil { + // Cannot proceed without an ID - this would break tool_use_id linkage + return nil + } + + // 1. Create server_tool_use block for the code interpreter + serverToolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeServerToolUse, + Name: schemas.Ptr(string(AnthropicToolNameBashCodeExecution)), + ID: msg.ID, // Always set - required for tool result linkage + } + + // Wrap the code in bash command format: python3 << 'EOF'\n...\nEOF\n + if codeInterpreter.Code != nil { + command := *codeInterpreter.Code + // Wrap in heredoc format if not already wrapped + if !strings.HasPrefix(command, "python") && !strings.HasPrefix(command, "bash") { + command = fmt.Sprintf("python3 << 'EOF'\n%s\nEOF\n", command) + } + input := map[string]interface{}{ + "command": command, + } + serverToolUseBlock.Input = input + } + + blocks = append(blocks, serverToolUseBlock) + + // 2. Create bash_code_execution_tool_result block if outputs are present + if len(codeInterpreter.Outputs) > 0 { + for _, output := range codeInterpreter.Outputs { + // Initialize stdout and stderr with empty strings (Anthropic expects these fields) + stdout := "" + stderr := "" + var returnCode *int + + // Handle logs output + if output.ResponsesCodeInterpreterOutputLogs != nil { + logs := output.ResponsesCodeInterpreterOutputLogs + returnCode = logs.ReturnCode + if returnCode != nil { + if *returnCode == 0 { + stdout = logs.Logs + } else { + stderr = logs.Logs + } + } else { + // If return code is not present, use the logs as stdout + stdout = logs.Logs + } + } + + // Create the bash_code_execution_result content block + // This must include type, stdout, stderr, return_code, and an empty content array + bashResultContent := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeBashCodeExecutionResult, + StdOut: &stdout, + StdErr: &stderr, + ReturnCode: returnCode, + Content: &AnthropicContent{ + ContentBlocks: []AnthropicContentBlock{}, // Empty array as per Anthropic spec + }, + } + + // Ensure we have a valid tool_use_id (critical for linkage) + if msg.ID == nil { + // Skip if no ID - this would be a hard break + continue + } + + // Create the bash_code_execution_tool_result block + bashResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeBashCodeExecutionToolResult, + ToolUseID: msg.ID, // Must match the server_tool_use.id + Content: &AnthropicContent{ + ContentBlock: &bashResultContent, + }, + } + + blocks = append(blocks, bashResultBlock) + } + } + + return blocks +} + // convertBifrostUnsupportedToolCallToAnthropicMessage converts unsupported tool calls to text messages func convertBifrostUnsupportedToolCallToAnthropicMessage(msg *schemas.ResponsesMessage, msgType schemas.ResponsesMessageType) *AnthropicMessage { if msg.ResponsesToolMessage != nil { @@ -3800,28 +3991,6 @@ func convertBifrostUnsupportedToolCallToAnthropicMessage(msg *schemas.ResponsesM return nil } -// convertBifrostComputerCallOutputToAnthropicMessage converts a Bifrost computer call output to Anthropic message -func convertBifrostComputerCallOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { - if msg.ResponsesToolMessage != nil { - toolResultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolResult, - ToolUseID: msg.ResponsesToolMessage.CallID, - } - - if msg.ResponsesToolMessage.Output != nil { - toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } - - return &AnthropicMessage{ - Role: AnthropicMessageRoleUser, - Content: AnthropicContent{ - ContentBlocks: []AnthropicContentBlock{toolResultBlock}, - }, - } - } - return nil -} - // convertBifrostToolOutputToAnthropicMessage converts tool outputs to user messages func convertBifrostToolOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { if msg.ResponsesToolMessage != nil { @@ -3873,7 +4042,6 @@ func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { } } return bifrostTool - case AnthropicToolTypeWebSearch20250305: bifrostTool := &schemas.ResponsesTool{ Type: schemas.ResponsesToolTypeWebSearch, @@ -3901,12 +4069,10 @@ func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { bfToolJSON, _ := json.MarshalIndent(bifrostTool, "", " ") fmt.Println("bifrostTool", string(bfToolJSON)) return bifrostTool - case AnthropicToolTypeBash20250124: return &schemas.ResponsesTool{ Type: schemas.ResponsesToolTypeLocalShell, } - case AnthropicToolTypeTextEditor20250124: return &schemas.ResponsesTool{ Type: schemas.ResponsesToolType(AnthropicToolTypeTextEditor20250124), @@ -3922,6 +4088,11 @@ func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { Type: schemas.ResponsesToolType(AnthropicToolTypeTextEditor20250728), Name: &tool.Name, } + case AnthropicToolTypeCodeExecution: + return &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeCodeInterpreter, + Name: &tool.Name, + } } } @@ -3977,27 +4148,6 @@ func convertAnthropicToolChoiceToBifrost(toolChoice *AnthropicToolChoice) *schem return bifrostToolChoice } -// flushPendingContentBlocks is a helper that flushes accumulated content blocks into an assistant message -func flushPendingContentBlocks( - pendingContentBlocks []AnthropicContentBlock, - currentAssistantMessage *AnthropicMessage, - anthropicMessages []AnthropicMessage, -) ([]AnthropicContentBlock, *AnthropicMessage, []AnthropicMessage) { - if len(pendingContentBlocks) > 0 && currentAssistantMessage != nil { - // Copy the slice to avoid aliasing issues - copied := make([]AnthropicContentBlock, len(pendingContentBlocks)) - copy(copied, pendingContentBlocks) - currentAssistantMessage.Content = AnthropicContent{ - ContentBlocks: copied, - } - anthropicMessages = append(anthropicMessages, *currentAssistantMessage) - // Return nil values to indicate flushed state - return nil, nil, anthropicMessages - } - // Return unchanged values if no flush was needed - return pendingContentBlocks, currentAssistantMessage, anthropicMessages -} - // convertToolOutputToAnthropicContent converts tool output to Anthropic content format func convertToolOutputToAnthropicContent(output *schemas.ResponsesToolMessageOutputStruct) *AnthropicContent { if output == nil { @@ -4093,6 +4243,11 @@ func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool) *A Type: schemas.Ptr(AnthropicToolTypeBash20250124), Name: string(AnthropicToolNameBash), } + case schemas.ResponsesToolTypeCodeInterpreter: + return &AnthropicTool{ + Type: schemas.Ptr(AnthropicToolTypeCodeExecution), + Name: string(AnthropicToolNameCodeExecution), + } case schemas.ResponsesToolType(AnthropicToolTypeTextEditor20250124): return &AnthropicTool{ Type: schemas.Ptr(AnthropicToolTypeTextEditor20250124), diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 12e6fef1e..698d6d89e 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -158,23 +158,27 @@ type AnthropicMessage struct { type AnthropicContent struct { ContentStr *string ContentBlocks []AnthropicContentBlock + ContentBlock *AnthropicContentBlock // For "bash_code_execution_tool_result" } // MarshalJSON implements custom JSON marshalling for AnthropicContent. // It marshals either ContentStr or ContentBlocks directly without wrapping. func (mc AnthropicContent) MarshalJSON() ([]byte, error) { // Validation: ensure only one field is set at a time - if mc.ContentStr != nil && mc.ContentBlocks != nil { - return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") + if mc.ContentStr != nil && mc.ContentBlocks != nil && mc.ContentBlock != nil { + return nil, fmt.Errorf("both ContentStr, ContentBlocks and ContentBlock are set; only one should be non-nil") } if mc.ContentStr != nil { return sonic.Marshal(*mc.ContentStr) } - if mc.ContentBlocks != nil { + if mc.ContentBlock != nil && mc.ContentBlocks == nil { + return sonic.Marshal(*mc.ContentBlock) + } + if mc.ContentBlocks != nil && mc.ContentBlock == nil { return sonic.Marshal(mc.ContentBlocks) } - // If both are nil, return null + // If all are nil, return null return sonic.Marshal(nil) } @@ -188,6 +192,13 @@ func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { return nil } + // Try to unmarshal as a direct ContentBlock + var contentBlock AnthropicContentBlock + if err := sonic.Unmarshal(data, &contentBlock); err == nil { + mc.ContentBlock = &contentBlock + return nil + } + // Try to unmarshal as a direct array of ContentBlock var arrayContent []AnthropicContentBlock if err := sonic.Unmarshal(data, &arrayContent); err == nil { @@ -195,24 +206,26 @@ func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { return nil } - return fmt.Errorf("content field is neither a string nor an array of ContentBlock") + return fmt.Errorf("content field is neither a string, ContentBlock nor an array of ContentBlock") } type AnthropicContentBlockType string const ( - AnthropicContentBlockTypeText AnthropicContentBlockType = "text" - AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" - AnthropicContentBlockTypeDocument AnthropicContentBlockType = "document" - AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" - AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" - AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" - AnthropicContentBlockTypeWebSearchToolResult AnthropicContentBlockType = "web_search_tool_result" - AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" - AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" - AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" - AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" - AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" + AnthropicContentBlockTypeText AnthropicContentBlockType = "text" + AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" + AnthropicContentBlockTypeDocument AnthropicContentBlockType = "document" + AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" + AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" + AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" + AnthropicContentBlockTypeWebSearchToolResult AnthropicContentBlockType = "web_search_tool_result" + AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" + AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" + AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" + AnthropicContentBlockTypeBashCodeExecutionToolResult AnthropicContentBlockType = "bash_code_execution_tool_result" + AnthropicContentBlockTypeBashCodeExecutionResult AnthropicContentBlockType = "bash_code_execution_result" + AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" + AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" ) // AnthropicContentBlock represents content in Anthropic message format @@ -236,6 +249,9 @@ type AnthropicContentBlock struct { URL *string `json:"url,omitempty"` // For web_search_result content EncryptedContent *string `json:"encrypted_content,omitempty"` // For web_search_result content PageAge *string `json:"page_age,omitempty"` // For web_search_result content + StdOut *string `json:"stdout,omitempty"` // For bash_code_execution_result content + StdErr *string `json:"stderr,omitempty"` // For bash_code_execution_result content + ReturnCode *int `json:"return_code,omitempty"` // For bash_code_execution_result content } // AnthropicSource represents image or document source in Anthropic format @@ -360,10 +376,12 @@ const ( type AnthropicToolName string const ( - AnthropicToolNameComputer AnthropicToolName = "computer" - AnthropicToolNameWebSearch AnthropicToolName = "web_search" - AnthropicToolNameBash AnthropicToolName = "bash" - AnthropicToolNameTextEditor AnthropicToolName = "str_replace_based_edit_tool" + AnthropicToolNameComputer AnthropicToolName = "computer" + AnthropicToolNameWebSearch AnthropicToolName = "web_search" + AnthropicToolNameBash AnthropicToolName = "bash" + AnthropicToolNameTextEditor AnthropicToolName = "str_replace_based_edit_tool" + AnthropicToolNameBashCodeExecution AnthropicToolName = "bash_code_execution" + AnthropicToolNameCodeExecution AnthropicToolName = "code_execution" ) type AnthropicToolComputerUse struct { @@ -451,6 +469,7 @@ type AnthropicMessageResponse struct { Model string `json:"model"` StopReason AnthropicStopReason `json:"stop_reason,omitempty"` StopSequence *string `json:"stop_sequence,omitempty"` + Container *AnthropicContainer `json:"container,omitempty"` Usage *AnthropicUsage `json:"usage,omitempty"` } @@ -475,6 +494,11 @@ type AnthropicUsage struct { OutputTokens int `json:"output_tokens"` } +type AnthropicContainer struct { + ID string `json:"id"` + ExpiresAt string `json:"expires_at"` // ISO 8601 timestamp when the container expires (sent by Anthropic) +} + type AnthropicUsageCacheCreation struct { Ephemeral5mInputTokens int `json:"ephemeral_5m_input_tokens"` Ephemeral1hInputTokens int `json:"ephemeral_1h_input_tokens"` diff --git a/core/providers/gemini/count_tokens.go b/core/providers/gemini/count_tokens.go index 24c072137..7855fd826 100644 --- a/core/providers/gemini/count_tokens.go +++ b/core/providers/gemini/count_tokens.go @@ -16,16 +16,32 @@ func (resp *GeminiCountTokensResponse) ToBifrostCountTokensResponse(model string inputTokens := 0 inputDetails := &schemas.ResponsesResponseInputTokens{} - for _, m := range resp.PromptTokensDetails { - if m == nil { - continue - } - inputTokens += int(m.TokenCount) - mod := strings.ToLower(m.Modality) - // handle audio modality - if strings.Contains(mod, "audio") { - inputDetails.AudioTokens += int(m.TokenCount) + // Convert PromptTokensDetails to ModalityTokenCount + if len(resp.PromptTokensDetails) > 0 { + modalityDetails := make([]schemas.ModalityTokenCount, 0, len(resp.PromptTokensDetails)) + for _, m := range resp.PromptTokensDetails { + if m == nil { + continue + } + inputTokens += int(m.TokenCount) + mod := strings.ToLower(m.Modality) + + // Add to modality token count + modalityDetails = append(modalityDetails, schemas.ModalityTokenCount{ + Modality: m.Modality, + TokenCount: int(m.TokenCount), + }) + + // Also populate specific fields for common modalities + if strings.Contains(mod, "audio") { + inputDetails.AudioTokens += int(m.TokenCount) + } else if strings.Contains(mod, "text") { + inputDetails.TextTokens += int(m.TokenCount) + } else if strings.Contains(mod, "image") { + inputDetails.ImageTokens += int(m.TokenCount) + } } + inputDetails.ModalityTokenCount = modalityDetails } // Set cached tokens from top-level field if present diff --git a/core/providers/gemini/responses.go b/core/providers/gemini/responses.go index 4db3ef904..e9f8f9037 100644 --- a/core/providers/gemini/responses.go +++ b/core/providers/gemini/responses.go @@ -146,7 +146,12 @@ func (response *GenerateContentResponse) ToResponsesBifrostResponsesResponse() * Model: response.ModelVersion, } - // Convert usage information + // Map Gemini's responseId to the standard ID field + if response.ResponseID != "" { + bifrostResp.ID = &response.ResponseID + } + + // Convert usage information with extended token details bifrostResp.Usage = convertGeminiUsageMetadataToResponsesUsage(response.UsageMetadata) // Convert candidates to Responses output messages @@ -160,6 +165,16 @@ func (response *GenerateContentResponse) ToResponsesBifrostResponsesResponse() * return bifrostResp } +// ToGeminiResponsesResponse converts a BifrostResponsesResponse back to Gemini's GenerateContentResponse format. +// This is the reverse transformation of ToResponsesBifrostResponsesResponse, used when returning responses +// to clients that expect Gemini's native format. +// +// Key conversion rules: +// 1. Code execution: ResponsesMessageTypeCodeInterpreterCall messages are converted to executableCode + codeExecutionResult parts +// 2. Thought signatures: Reasoning messages with EncryptedContent are attached to adjacent parts (executableCode, functionCall, or text) +// rather than being emitted as standalone thoughtSignature parts +// 3. Function calls: ResponsesMessageTypeFunctionCall messages become functionCall parts with signatures preserved +// 4. Text content: Content blocks with Signature fields get their signatures attached to the text parts func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *GenerateContentResponse { if bifrostResp == nil { return nil @@ -169,7 +184,7 @@ func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *G ModelVersion: bifrostResp.Model, } - // Set response ID if available + // Map the standard ID field to Gemini's responseId if bifrostResp.ID != nil { geminiResp.ResponseID = *bifrostResp.ID } @@ -317,6 +332,70 @@ func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *G }) } + // Handle code interpreter calls (code execution) + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeCodeInterpreterCall && msg.ResponsesToolMessage != nil { + if msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall != nil { + codeInterpreter := msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall + + // Look back to see if the previous message is a reasoning message with encrypted content + // If so, extract the thought signature and mark it as consumed + var thoughtSig []byte + if i > 0 { + prevMsg := bifrostResp.Output[i-1] + if prevMsg.Type != nil && *prevMsg.Type == schemas.ResponsesMessageTypeReasoning && + prevMsg.ResponsesReasoning != nil && prevMsg.ResponsesReasoning.EncryptedContent != nil { + decodedSig, err := base64.StdEncoding.DecodeString(*prevMsg.ResponsesReasoning.EncryptedContent) + if err == nil { + thoughtSig = decodedSig + // Mark the previous reasoning message as consumed + consumedIndices[i-1] = true + } + } + } + + // Create the ExecutableCode part with thought signature attached + if codeInterpreter.Code != nil && codeInterpreter.Language != nil { + executablePart := &Part{ + ExecutableCode: &ExecutableCode{ + Language: *codeInterpreter.Language, + Code: *codeInterpreter.Code, + }, + } + + // Attach thought signature to the executableCode part + if len(thoughtSig) > 0 { + executablePart.ThoughtSignature = thoughtSig + } + + currentParts = append(currentParts, executablePart) + } + + // Add CodeExecutionResult parts for each output + if len(codeInterpreter.Outputs) > 0 { + for _, output := range codeInterpreter.Outputs { + if output.ResponsesCodeInterpreterOutputLogs != nil { + logs := output.ResponsesCodeInterpreterOutputLogs + + // Map return code to Gemini outcome + outcome := OutcomeOK + if logs.ReturnCode != nil && *logs.ReturnCode != 0 { + outcome = OutcomeFailed + } + + resultPart := &Part{ + CodeExecutionResult: &CodeExecutionResult{ + Outcome: outcome, + Output: logs.Logs, + }, + } + + currentParts = append(currentParts, resultPart) + } + } + } + } + } + // Handle reasoning messages if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeReasoning && msg.ResponsesReasoning != nil { // Skip this reasoning message if it was already consumed as a thought signature @@ -335,14 +414,10 @@ func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *G } } } - if msg.ResponsesReasoning.EncryptedContent != nil { - decodedSig, err := base64.StdEncoding.DecodeString(*msg.ResponsesReasoning.EncryptedContent) - if err == nil { - currentParts = append(currentParts, &Part{ - ThoughtSignature: decodedSig, - }) - } - } + + // Standalone thoughtSignature without summary should be attached to adjacent parts + // Don't create a standalone part - this will be handled by looking ahead/behind + // in the code execution and function call handlers above } } @@ -383,8 +458,52 @@ func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *G CandidatesTokenCount: int32(bifrostResp.Usage.OutputTokens), TotalTokenCount: int32(bifrostResp.Usage.TotalTokens), } + + // Add thoughts token count if bifrostResp.Usage.OutputTokensDetails != nil { geminiResp.UsageMetadata.ThoughtsTokenCount = int32(bifrostResp.Usage.OutputTokensDetails.ReasoningTokens) + + // Add candidates tokens details (output modality breakdown) + if len(bifrostResp.Usage.OutputTokensDetails.ModalityTokenCount) > 0 { + details := make([]*ModalityTokenCount, len(bifrostResp.Usage.OutputTokensDetails.ModalityTokenCount)) + for i, detail := range bifrostResp.Usage.OutputTokensDetails.ModalityTokenCount { + details[i] = &ModalityTokenCount{ + Modality: detail.Modality, + TokenCount: int32(detail.TokenCount), + } + } + geminiResp.UsageMetadata.CandidatesTokensDetails = details + } + } + + // Add input tokens details + if bifrostResp.Usage.InputTokensDetails != nil { + // Add cached tokens + if bifrostResp.Usage.InputTokensDetails.CachedTokens > 0 { + geminiResp.UsageMetadata.CachedContentTokenCount = int32(bifrostResp.Usage.InputTokensDetails.CachedTokens) + } + + // Add tool use prompt token count + if bifrostResp.Usage.InputTokensDetails.ToolUseTokens > 0 { + geminiResp.UsageMetadata.ToolUsePromptTokenCount = int32(bifrostResp.Usage.InputTokensDetails.ToolUseTokens) + } + + // Add prompt tokens details (modality breakdown) + if len(bifrostResp.Usage.InputTokensDetails.ModalityTokenCount) > 0 { + // Split modality details between tool use and prompt + // For now, we'll put them all in PromptTokensDetails since we can't easily distinguish + details := make([]*ModalityTokenCount, len(bifrostResp.Usage.InputTokensDetails.ModalityTokenCount)) + for i, detail := range bifrostResp.Usage.InputTokensDetails.ModalityTokenCount { + details[i] = &ModalityTokenCount{ + Modality: detail.Modality, + TokenCount: int32(detail.TokenCount), + } + } + geminiResp.UsageMetadata.PromptTokensDetails = details + // Also set ToolUsePromptTokensDetails to the same value + // Gemini might use either field depending on context + geminiResp.UsageMetadata.ToolUsePromptTokensDetails = details + } } } @@ -1759,6 +1878,14 @@ func convertGeminiToolsToResponsesTools(tools []Tool) []schemas.ResponsesTool { responsesTools = append(responsesTools, responsesTool) } } + + // Handle CodeExecution tool + if tool.CodeExecution != nil { + responsesTool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeCodeInterpreter, + } + responsesTools = append(responsesTools, responsesTool) + } } return responsesTools @@ -1806,7 +1933,7 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas continue } - for _, part := range candidate.Content.Parts { + for partIdx, part := range candidate.Content.Parts { // Handle different types of parts switch { case part.Thought: @@ -1983,44 +2110,90 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas } messages = append(messages, msg) - case part.CodeExecutionResult != nil: - // Handle code execution results - output := part.CodeExecutionResult.Output - if part.CodeExecutionResult.Outcome != OutcomeOK { - output = "Error: " + output + case part.ExecutableCode != nil: + // Handle executable code - create a code_interpreter_call message + // Generate a unique ID for this code execution + codeID := fmt.Sprintf("code_%s", providerUtils.GetRandomString(32)) + + // If there's a thought signature, create a reasoning message BEFORE the code_interpreter_call + // This preserves the order so we can reconstruct it correctly + if part.ThoughtSignature != nil { + thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature) + reasoningMsg := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + providerUtils.GetRandomString(32)), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + EncryptedContent: &thoughtSig, + }, + } + messages = append(messages, reasoningMsg) } msg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &output, - }, + ID: &codeID, + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Type: schemas.Ptr(schemas.ResponsesMessageTypeCodeInterpreterCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &codeID, + ResponsesCodeInterpreterToolCall: &schemas.ResponsesCodeInterpreterToolCall{ + Code: &part.ExecutableCode.Code, + Language: &part.ExecutableCode.Language, + ContainerID: "", // Gemini doesn't use containers, so we use an empty string + Outputs: []schemas.ResponsesCodeInterpreterOutput{}, }, }, - Type: schemas.Ptr(schemas.ResponsesMessageTypeCodeInterpreterCall), } - messages = append(messages, msg) - case part.ExecutableCode != nil: - // Handle executable code - codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```" + // Look ahead to find the corresponding CodeExecutionResult + // It should be in the next part or parts + for j := partIdx + 1; j < len(candidate.Content.Parts); j++ { + nextPart := candidate.Content.Parts[j] + if nextPart.CodeExecutionResult != nil { + // Add the execution result as output + var log schemas.ResponsesCodeInterpreterOutputLogs + log.Type = "logs" + + // Map Gemini outcome to return code + var returnCode int + switch nextPart.CodeExecutionResult.Outcome { + case OutcomeOK: + returnCode = 0 + log.Logs = nextPart.CodeExecutionResult.Output + case OutcomeFailed: + returnCode = 1 + log.Logs = nextPart.CodeExecutionResult.Output + case OutcomeDeadlineExceeded: + returnCode = 124 // Standard timeout exit code + log.Logs = nextPart.CodeExecutionResult.Output + default: + returnCode = 1 + log.Logs = nextPart.CodeExecutionResult.Output + } + log.ReturnCode = &returnCode - msg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &codeContent, + msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs = append( + msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs, + schemas.ResponsesCodeInterpreterOutput{ + ResponsesCodeInterpreterOutputLogs: &log, }, - }, - }, - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + ) + break + } + // Stop looking if we hit another executable code or other significant content + if nextPart.ExecutableCode != nil || nextPart.Text != "" || nextPart.FunctionCall != nil { + break + } } + messages = append(messages, msg) + + case part.CodeExecutionResult != nil: + // Skip standalone CodeExecutionResult - it should already be handled with ExecutableCode + // This case is here to prevent it from being treated as unknown content + continue case part.ThoughtSignature != nil: // Handle thought signature thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature) @@ -2242,10 +2415,14 @@ func convertResponsesToolsToGemini(tools []schemas.ResponsesTool) []Tool { geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, funcDecl) } } + } else if tool.Type == schemas.ResponsesToolTypeCodeInterpreter { + // Add CodeExecution tool + geminiTool.CodeExecution = &ToolCodeExecution{} } } - if len(geminiTool.FunctionDeclarations) > 0 { + // Return the tool if it has any declarations or code execution + if len(geminiTool.FunctionDeclarations) > 0 || geminiTool.CodeExecution != nil { return []Tool{geminiTool} } return []Tool{} @@ -2375,6 +2552,65 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag // Handle tool calls from assistant messages if msg.ResponsesToolMessage != nil && msg.Type != nil { switch *msg.Type { + case schemas.ResponsesMessageTypeCodeInterpreterCall: + // Convert code_interpreter_call to Gemini ExecutableCode + CodeExecutionResult + if msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall != nil { + codeInterpreter := msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall + + var thoughtSig []byte + // Look back to see if the previous message is a reasoning message with encrypted content + if i > 0 { + prevMsg := messages[i-1] + if prevMsg.Type != nil && *prevMsg.Type == schemas.ResponsesMessageTypeReasoning && + prevMsg.ResponsesReasoning != nil && prevMsg.ResponsesReasoning.EncryptedContent != nil { + decodedSig, err := base64.StdEncoding.DecodeString(*prevMsg.ResponsesReasoning.EncryptedContent) + if err == nil { + thoughtSig = decodedSig + } + } + } + + // Add ExecutableCode part with thought signature attached + if codeInterpreter.Code != nil && codeInterpreter.Language != nil { + executablePart := &Part{ + ExecutableCode: &ExecutableCode{ + Language: *codeInterpreter.Language, + Code: *codeInterpreter.Code, + }, + } + + // Attach thought signature to the executableCode part only + if len(thoughtSig) > 0 { + executablePart.ThoughtSignature = thoughtSig + } + + content.Parts = append(content.Parts, executablePart) + } + + // Add CodeExecutionResult parts for each output (without thoughtSignature) + if len(codeInterpreter.Outputs) > 0 { + for _, output := range codeInterpreter.Outputs { + if output.ResponsesCodeInterpreterOutputLogs != nil { + logs := output.ResponsesCodeInterpreterOutputLogs + + // Map return code to Gemini outcome + outcome := OutcomeOK + if logs.ReturnCode != nil && *logs.ReturnCode != 0 { + outcome = OutcomeFailed + } + + resultPart := &Part{ + CodeExecutionResult: &CodeExecutionResult{ + Outcome: outcome, + Output: logs.Logs, + }, + } + + content.Parts = append(content.Parts, resultPart) + } + } + } + } case schemas.ResponsesMessageTypeFunctionCall: // Convert function call to Gemini FunctionCall if msg.ResponsesToolMessage.Name != nil { diff --git a/core/providers/gemini/responses_test.go b/core/providers/gemini/responses_test.go new file mode 100644 index 000000000..0ac1cdd55 --- /dev/null +++ b/core/providers/gemini/responses_test.go @@ -0,0 +1,254 @@ +package gemini + +import ( + "encoding/base64" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" +) + +// TestCodeExecutionRoundTrip tests that code execution with thought signatures +// is correctly converted from Gemini -> Bifrost -> Gemini +func TestCodeExecutionRoundTrip(t *testing.T) { + // Create a raw Gemini response with code execution + thoughtSig1 := []byte("test-thought-signature-1") + thoughtSig2 := []byte("test-thought-signature-2") + + originalResponse := &GenerateContentResponse{ + ModelVersion: "gemini-3-flash-preview", + ResponseID: "test-response-id", + Candidates: []*Candidate{ + { + Index: 0, + Content: &Content{ + Parts: []*Part{ + { + ExecutableCode: &ExecutableCode{ + Language: "PYTHON", + Code: "def is_prime(n):\n if n < 2:\n return False\n return True\n", + }, + ThoughtSignature: thoughtSig1, + }, + { + CodeExecutionResult: &CodeExecutionResult{ + Outcome: OutcomeOK, + Output: "Primes: [2, 3, 5, 7]\nSum: 17\n", + }, + }, + { + Text: "The sum of the first 4 primes is 17.", + ThoughtSignature: thoughtSig2, + }, + }, + Role: "model", + }, + FinishReason: FinishReasonStop, + }, + }, + UsageMetadata: &GenerateContentResponseUsageMetadata{ + ThoughtsTokenCount: 100, + PromptTokenCount: 50, + CandidatesTokenCount: 150, + TotalTokenCount: 300, + }, + } + + // Step 1: Convert Gemini -> Bifrost + bifrostResp := originalResponse.ToResponsesBifrostResponsesResponse() + assert.NotNil(t, bifrostResp) + assert.NotNil(t, bifrostResp.Output) + + // Verify we have the expected messages: + // - A reasoning message with encrypted content (thought signature for code) + // - A code_interpreter_call message + // - A message with text content (with thought signature in the content block) + foundReasoning := false + foundCodeInterpreter := false + foundText := false + + for _, msg := range bifrostResp.Output { + if msg.Type != nil { + switch *msg.Type { + case schemas.ResponsesMessageTypeReasoning: + if msg.ResponsesReasoning != nil && msg.ResponsesReasoning.EncryptedContent != nil { + foundReasoning = true + // Verify the thought signature matches + decoded, err := base64.StdEncoding.DecodeString(*msg.ResponsesReasoning.EncryptedContent) + assert.NoError(t, err) + assert.Equal(t, thoughtSig1, decoded) + } + case schemas.ResponsesMessageTypeCodeInterpreterCall: + foundCodeInterpreter = true + assert.NotNil(t, msg.ResponsesToolMessage) + assert.NotNil(t, msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall) + assert.NotNil(t, msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Code) + assert.Contains(t, *msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Code, "is_prime") + // Verify outputs + assert.Len(t, msg.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs, 1) + case schemas.ResponsesMessageTypeMessage: + if msg.Content != nil && len(msg.Content.ContentBlocks) > 0 { + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + foundText = true + // Verify the thought signature is in the content block + assert.NotNil(t, block.Signature) + decoded, err := base64.StdEncoding.DecodeString(*block.Signature) + assert.NoError(t, err) + assert.Equal(t, thoughtSig2, decoded) + } + } + } + } + } + } + + assert.True(t, foundReasoning, "Should have a reasoning message with thought signature") + assert.True(t, foundCodeInterpreter, "Should have a code_interpreter_call message") + assert.True(t, foundText, "Should have a text message with thought signature") + + // Step 2: Convert Bifrost -> Gemini + reconstructedResponse := ToGeminiResponsesResponse(bifrostResp) + assert.NotNil(t, reconstructedResponse) + assert.Len(t, reconstructedResponse.Candidates, 1) + + candidate := reconstructedResponse.Candidates[0] + assert.NotNil(t, candidate.Content) + assert.Len(t, candidate.Content.Parts, 3, "Should have exactly 3 parts: executableCode, codeExecutionResult, text") + + // Verify Part 0: ExecutableCode with thought signature + part0 := candidate.Content.Parts[0] + assert.NotNil(t, part0.ExecutableCode, "Part 0 should have ExecutableCode") + assert.Equal(t, "PYTHON", part0.ExecutableCode.Language) + assert.Contains(t, part0.ExecutableCode.Code, "is_prime") + assert.NotNil(t, part0.ThoughtSignature, "Part 0 should have ThoughtSignature attached to ExecutableCode") + assert.Equal(t, thoughtSig1, part0.ThoughtSignature) + // Verify it's NOT a standalone part with only thoughtSignature + assert.NotNil(t, part0.ExecutableCode, "Part 0 should have ExecutableCode, not be a standalone thoughtSignature") + + // Verify Part 1: CodeExecutionResult (no thought signature) + part1 := candidate.Content.Parts[1] + assert.NotNil(t, part1.CodeExecutionResult, "Part 1 should have CodeExecutionResult") + assert.Equal(t, OutcomeOK, part1.CodeExecutionResult.Outcome) + assert.Contains(t, part1.CodeExecutionResult.Output, "Primes:") + assert.Nil(t, part1.ThoughtSignature, "Part 1 should NOT have ThoughtSignature") + + // Verify Part 2: Text with thought signature + part2 := candidate.Content.Parts[2] + assert.NotEmpty(t, part2.Text, "Part 2 should have text") + assert.Contains(t, part2.Text, "The sum") + assert.NotNil(t, part2.ThoughtSignature, "Part 2 should have ThoughtSignature attached to text") + assert.Equal(t, thoughtSig2, part2.ThoughtSignature) + + // Verify usage metadata is preserved + assert.NotNil(t, reconstructedResponse.UsageMetadata) + assert.Equal(t, int32(100), reconstructedResponse.UsageMetadata.ThoughtsTokenCount) +} + +// TestRealGeminiCodeExecutionResponse tests with a real Gemini API response format +// This mimics the exact scenario from the user's bug report +func TestRealGeminiCodeExecutionResponse(t *testing.T) { + // This is based on the actual Gemini API response format from the user's example + responseJSON := `{ + "candidates": [ + { + "content": { + "parts": [ + { + "executableCode": { + "language": "PYTHON", + "code": "def is_prime(n):\n if n < 2:\n return False\n for i in range(2, int(n**0.5) + 1):\n if n % i == 0:\n return False\n return True\n\nprimes = []\nnum = 2\nwhile len(primes) < 50:\n if is_prime(num):\n primes.append(num)\n num += 1\n\nsum_primes = sum(primes)\nprint(f\"Primes: {primes}\")\nprint(f\"Sum: {sum_primes}\")\n" + }, + "thoughtSignature": "EuwECukEAXLI2nwr9f360fnlN/uEDL6wJ7+EwKhtt18hOp/oZCpTTasGUoESz9+xnYSaixBB2LB/EKwsSUctFq1IvE9uDimfhaDIGuwgCPMeNQq1lXOGbIygrQuJPCQsQZTk5WlK0FT1c3ZFlDC0uJwCSgDUGdfzD+wcJJ33hR0i6nt6XTQutza0CgmySerlgFUagXiVbP/9iTLVQnSmdT/VLtFIs0Ekf0StxfdV8jqank5MESI1qNR+YoMF+04IgkvTMvY8jDNgvmwBoslDbqtnDN0411bpWq5SVTQ+yv4m9RtVjUn8cz3kerdoUjSIf7d26XtXOuCDTu0HXkna7RX9Bovlwc1YM1hzmZ8sPqAj/Da7QfABH02//be/UUYGYXA10raamU1mkYeYK0hXpA/JG90Vpjbp9BIwLptnDGwZShwVg8m9zb5xmZZFC19XkzhH9hfcUlKos1TosCTWCfwEwkhs/8AuEhyZ/0/MQEvy4iZ7D3FmWXSTqM/8sjyKTPLn8WIJlFDnbQ4sUSPqd3qirQhTjgaeMWXHKnAFdWdukUYM7OIuQy913hhpLc0mI4mJc7wlTdMz+K9yx85uRquLTVci8V4+S0/iOxXxAxaGk1aklS99S5G4wj7MBT/hUNhPmxN7qvV2ujmmVJPi6cDKg4oonYhXwPWFjlvEPQedRMQAZD6edtl3Xm0N97YtsuHjny7TFQkL1Q2DgHF9LVFnyE7uTuGK/5y2HXeB40AAi5dLWyb+gGFu8mwXqxReVHZL27S0Q2qcmgBzlzhM4aMTodJWJ/jUIjQnWNLz2EJ8pkEb31lqwMONVTFJ8m4=" + }, + { + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "Primes: [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229]\nSum: 5117\n" + } + }, + { + "text": "The sum of the first 50 prime numbers is **5,117**.", + "thoughtSignature": "Ev4CCvsCAXLI2nwD2FnAquYRanHjC6DTz+yxDNQ/SUTjvh+nD4HpPHMUcma/tTTHc9LQWVwjfPhUvZc44NzRh0sIq8m2Mdfpl+KyB/cQHl0gbPoqOnVk8Yh6R9EIF8EIx4J6OSAnU5lUq8l/6blwST7LzB3OeCgL4saVBs0dGIvv24LveNdWC6PbjUdSExAeyx+i7+JqLjzkcsV9m9S0hDZXImKpCTFydisY+/pFXM5R362KPGT1bCT2Jvs/i+SEkuMuPcXHabADPpCtnv7BOX+5RnD7A46n93peL2AlJnxkK1wdRuOEObABy/wGPbLqwRszUt7z0yzZS5igkH2KndcfuWzeo45OsMTMy+J461cLfGcJwBjVnmdgDNGWQROW4mTb6LIz1QeL0D/WUEHldWwbYgmGYf9RBzCeMGTBPPkkA3oAiYIgtsXmHuCAvNM2zkJUi9XZEQfRxBnyZNLsLIWZ/yY3FArZzDdHglHuYZX1FVi/1OEKlSTgPJPdAr5xxw==" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "thoughtsTokenCount": 224, + "promptTokenCount": 194, + "candidatesTokenCount": 419, + "totalTokenCount": 1557 + }, + "modelVersion": "gemini-3-flash-preview", + "responseId": "jaVnaYS1MZGjqfkPpM2m-AQ" +}` + + // Parse the JSON into a Gemini response + var geminiResp GenerateContentResponse + err := json.Unmarshal([]byte(responseJSON), &geminiResp) + assert.NoError(t, err) + + // Verify the original response has the expected structure + assert.Len(t, geminiResp.Candidates, 1) + assert.Len(t, geminiResp.Candidates[0].Content.Parts, 3) + + // Part 0: executableCode with thoughtSignature + assert.NotNil(t, geminiResp.Candidates[0].Content.Parts[0].ExecutableCode) + assert.NotNil(t, geminiResp.Candidates[0].Content.Parts[0].ThoughtSignature) + + // Part 1: codeExecutionResult (no thoughtSignature) + assert.NotNil(t, geminiResp.Candidates[0].Content.Parts[1].CodeExecutionResult) + assert.Nil(t, geminiResp.Candidates[0].Content.Parts[1].ThoughtSignature) + + // Part 2: text with thoughtSignature + assert.NotEmpty(t, geminiResp.Candidates[0].Content.Parts[2].Text) + assert.NotNil(t, geminiResp.Candidates[0].Content.Parts[2].ThoughtSignature) + + // Convert to Bifrost format + bifrostResp := geminiResp.ToResponsesBifrostResponsesResponse() + assert.NotNil(t, bifrostResp) + + // Convert back to Gemini format + reconstructed := ToGeminiResponsesResponse(bifrostResp) + assert.NotNil(t, reconstructed) + assert.Len(t, reconstructed.Candidates, 1) + assert.Len(t, reconstructed.Candidates[0].Content.Parts, 3, "Should have exactly 3 parts") + + // Verify the reconstructed response matches the original structure + parts := reconstructed.Candidates[0].Content.Parts + + // Part 0: executableCode with thoughtSignature (NOT a standalone thoughtSignature part!) + assert.NotNil(t, parts[0].ExecutableCode, "Part 0 must have executableCode") + assert.Equal(t, "PYTHON", parts[0].ExecutableCode.Language) + assert.Contains(t, parts[0].ExecutableCode.Code, "is_prime") + assert.NotNil(t, parts[0].ThoughtSignature, "Part 0 must have thoughtSignature attached to executableCode") + assert.Equal(t, geminiResp.Candidates[0].Content.Parts[0].ThoughtSignature, parts[0].ThoughtSignature) + + // Part 1: codeExecutionResult (no thoughtSignature) + assert.NotNil(t, parts[1].CodeExecutionResult, "Part 1 must have codeExecutionResult") + assert.Equal(t, OutcomeOK, parts[1].CodeExecutionResult.Outcome) + assert.Contains(t, parts[1].CodeExecutionResult.Output, "Primes:") + assert.Nil(t, parts[1].ThoughtSignature, "Part 1 should NOT have thoughtSignature") + + // Part 2: text with thoughtSignature + assert.NotEmpty(t, parts[2].Text, "Part 2 must have text") + assert.Contains(t, parts[2].Text, "5,117") + assert.NotNil(t, parts[2].ThoughtSignature, "Part 2 must have thoughtSignature attached to text") + assert.Equal(t, geminiResp.Candidates[0].Content.Parts[2].ThoughtSignature, parts[2].ThoughtSignature) + + // Ensure NO standalone thoughtSignature parts were created + for i, part := range parts { + // A standalone thoughtSignature part would have ONLY thoughtSignature set, nothing else + if part.ThoughtSignature != nil { + // It should have either executableCode, text, or functionCall + hasContent := part.ExecutableCode != nil || part.Text != "" || part.FunctionCall != nil + assert.True(t, hasContent, "Part %d has thoughtSignature but no content - this is a standalone thoughtSignature part which is incorrect", i) + } + } +} diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index 77bbd568e..f536cde48 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -66,7 +66,7 @@ const ( type GeminiGenerationRequest struct { Model string `json:"model,omitempty"` // Model field for explicit model specification - Contents []Content `json:"contents,omitempty"` // For chat completion requests + Contents []Content `json:"-"` // For chat completion requests - handled by custom unmarshaller Requests []GeminiEmbeddingRequest `json:"requests,omitempty"` // For batch embedding requests SystemInstruction *Content `json:"systemInstruction,omitempty"` GenerationConfig GenerationConfig `json:"generationConfig,omitempty"` @@ -85,6 +85,52 @@ type GeminiGenerationRequest struct { Fallbacks []string `json:"fallbacks,omitempty"` } +// UnmarshalJSON implements custom JSON unmarshaling for GeminiGenerationRequest. +// This handles the contents field which can be either a single Content object or an array of Content objects. +func (g *GeminiGenerationRequest) UnmarshalJSON(data []byte) error { + type Alias GeminiGenerationRequest + aux := &struct { + Contents json.RawMessage `json:"contents,omitempty"` + *Alias + }{ + Alias: (*Alias)(g), + } + + if err := sonic.Unmarshal(data, &aux); err != nil { + return err + } + + // Handle contents field - can be single object or array + if len(aux.Contents) > 0 { + // Try to unmarshal as array first + var contentsArray []Content + if err := sonic.Unmarshal(aux.Contents, &contentsArray); err == nil { + g.Contents = contentsArray + } else { + // If that fails, try as single object + var singleContent Content + if err := sonic.Unmarshal(aux.Contents, &singleContent); err != nil { + return fmt.Errorf("contents must be either a Content object or array of Content objects: %w", err) + } + g.Contents = []Content{singleContent} + } + } + + return nil +} + +// MarshalJSON implements custom JSON marshaling for GeminiGenerationRequest. +func (g GeminiGenerationRequest) MarshalJSON() ([]byte, error) { + type Alias GeminiGenerationRequest + return sonic.Marshal(&struct { + Contents []Content `json:"contents,omitempty"` + *Alias + }{ + Contents: g.Contents, + Alias: (*Alias)(&g), + }) +} + // IsStreamingRequested implements the StreamingRequest interface func (r *GeminiGenerationRequest) IsStreamingRequested() bool { return r.Stream @@ -1016,13 +1062,59 @@ type GeminiEmbeddingRequest struct { type Content struct { // Optional. List of parts that constitute a single message. Each part may have // a different IANA MIME type. - Parts []*Part `json:"parts,omitempty"` + Parts []*Part `json:"-"` // Handled by custom unmarshaller // Optional. The producer of the content. Must be either 'user' or // 'model'. Useful to set for multi-turn conversations, otherwise can be // empty. If role is not specified, SDK will determine the role. Role string `json:"role,omitempty"` } +// UnmarshalJSON implements custom JSON unmarshaling for Content. +// This handles the parts field which can be either a single Part object or an array of Part objects. +func (c *Content) UnmarshalJSON(data []byte) error { + type Alias Content + aux := &struct { + Parts json.RawMessage `json:"parts,omitempty"` + *Alias + }{ + Alias: (*Alias)(c), + } + + if err := sonic.Unmarshal(data, &aux); err != nil { + return err + } + + // Handle parts field - can be single object or array + if len(aux.Parts) > 0 { + // Try to unmarshal as array first + var partsArray []*Part + if err := sonic.Unmarshal(aux.Parts, &partsArray); err == nil { + c.Parts = partsArray + } else { + // If that fails, try as single object + var singlePart Part + if err := sonic.Unmarshal(aux.Parts, &singlePart); err != nil { + return fmt.Errorf("parts must be either a Part object or array of Part objects: %w", err) + } + c.Parts = []*Part{&singlePart} + } + } + + return nil +} + +// MarshalJSON implements custom JSON marshaling for Content. +func (c Content) MarshalJSON() ([]byte, error) { + type Alias Content + return sonic.Marshal(&struct { + Parts []*Part `json:"parts,omitempty"` + *Alias + }{ + Parts: c.Parts, + Alias: (*Alias)(&c), + }) +} + // Part is a datatype containing media content. // Exactly one field within a Part should be set, representing the specific type // of content being conveyed. Using multiple fields within the same `Part` diff --git a/core/providers/gemini/utils.go b/core/providers/gemini/utils.go index 2aa652bed..63b2e87ad 100644 --- a/core/providers/gemini/utils.go +++ b/core/providers/gemini/utils.go @@ -423,17 +423,82 @@ func convertGeminiUsageMetadataToResponsesUsage(metadata *GenerateContentRespons // Add cached tokens if present if metadata.CachedContentTokenCount > 0 { - usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ - CachedTokens: int(metadata.CachedContentTokenCount), + usage.InputTokensDetails.CachedTokens = int(metadata.CachedContentTokenCount) + } + + // Add tool use prompt token count if present + if metadata.ToolUsePromptTokenCount > 0 { + usage.InputTokensDetails.ToolUseTokens = int(metadata.ToolUsePromptTokenCount) + } + + // Add tool use prompt tokens details (modality breakdown) if present + if len(metadata.ToolUsePromptTokensDetails) > 0 { + details := make([]schemas.ModalityTokenCount, len(metadata.ToolUsePromptTokensDetails)) + for i, detail := range metadata.ToolUsePromptTokensDetails { + details[i] = schemas.ModalityTokenCount{ + Modality: detail.Modality, + TokenCount: int(detail.TokenCount), + } + } + usage.InputTokensDetails.ModalityTokenCount = details + } + + // Add prompt tokens details (modality breakdown) if present - merge with existing if needed + if len(metadata.PromptTokensDetails) > 0 { + // If we already have modality details from tool use, merge them + if len(usage.InputTokensDetails.ModalityTokenCount) > 0 { + // Create a map to merge modality counts + modalityMap := make(map[string]int) + for _, detail := range usage.InputTokensDetails.ModalityTokenCount { + modalityMap[detail.Modality] = detail.TokenCount + } + for _, detail := range metadata.PromptTokensDetails { + // Add or update the modality count + if existing, exists := modalityMap[detail.Modality]; exists { + modalityMap[detail.Modality] = existing + int(detail.TokenCount) + } else { + modalityMap[detail.Modality] = int(detail.TokenCount) + } + } + // Convert back to slice + details := make([]schemas.ModalityTokenCount, 0, len(modalityMap)) + for modality, count := range modalityMap { + details = append(details, schemas.ModalityTokenCount{ + Modality: modality, + TokenCount: count, + }) + } + usage.InputTokensDetails.ModalityTokenCount = details + } else { + // No existing modality details, just set from PromptTokensDetails + details := make([]schemas.ModalityTokenCount, len(metadata.PromptTokensDetails)) + for i, detail := range metadata.PromptTokensDetails { + details[i] = schemas.ModalityTokenCount{ + Modality: detail.Modality, + TokenCount: int(detail.TokenCount), + } + } + usage.InputTokensDetails.ModalityTokenCount = details } } + // Add output tokens details if metadata.CandidatesTokensDetails != nil { + outputDetails := make([]schemas.ModalityTokenCount, 0) for _, detail := range metadata.CandidatesTokensDetails { switch detail.Modality { case "AUDIO": usage.OutputTokensDetails.AudioTokens = int(detail.TokenCount) + case "TEXT": + usage.OutputTokensDetails.TextTokens = int(detail.TokenCount) } + outputDetails = append(outputDetails, schemas.ModalityTokenCount{ + Modality: detail.Modality, + TokenCount: int(detail.TokenCount), + }) + } + if len(outputDetails) > 0 { + usage.OutputTokensDetails.ModalityTokenCount = outputDetails } } diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 9c204a572..749a471af 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -389,16 +389,6 @@ func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif } } - // Try to unmarshal decoded body for RawResponse - var rawErrorResponse interface{} - if err := sonic.Unmarshal(decodedBody, &rawErrorResponse); err != nil { - if logger != nil { - logger.Warn(fmt.Sprintf("Failed to parse raw error response: %v", err)) - } - // If unmarshal fails (e.g., for HTML or plain text), store as string so RawResponse is never nil - rawErrorResponse = string(decodedBody) - } - // Check for empty response trimmed := strings.TrimSpace(string(decodedBody)) if len(trimmed) == 0 { @@ -409,11 +399,18 @@ func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif Message: schemas.ErrProviderResponseEmpty, }, ExtraFields: schemas.BifrostErrorExtraFields{ - RawResponse: rawErrorResponse, + RawResponse: nil, // No raw response for empty response }, } } + // Try to unmarshal decoded body for RawResponse + var rawErrorResponse interface{} + if err := sonic.Unmarshal(decodedBody, &rawErrorResponse); err != nil { + // If unmarshal fails (e.g., for HTML or plain text), store as string so RawResponse is never nil + rawErrorResponse = string(decodedBody) + } + // Try JSON parsing first if err := sonic.Unmarshal(decodedBody, errorResp); err == nil { // JSON parsing succeeded, return success diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 802e83de7..886aceffc 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -45,7 +45,8 @@ type BifrostResponsesResponse struct { Background *bool `json:"background,omitempty"` Conversation *ResponsesResponseConversation `json:"conversation,omitempty"` - CreatedAt int `json:"created_at"` // Unix timestamp when Response was created + CreatedAt int `json:"created_at"` // Unix timestamp when Response was created + CompletedAt *int `json:"completed_at"` // Unix timestamp when Response was completed Error *ResponsesResponseError `json:"error,omitempty"` Include []string `json:"include,omitempty"` // Supported values: "web_search_call.action.sources", "code_interpreter_call.outputs", "computer_call_output.output.image_url", "file_search_call.results", "message.input_image.image_url", "message.output_text.logprobs", "reasoning.encrypted_content" IncompleteDetails *ResponsesResponseIncompleteDetails `json:"incomplete_details,omitempty"` // Details about why the response is incomplete @@ -263,15 +264,23 @@ type ResponsesResponseUsage struct { } type ResponsesResponseInputTokens struct { - TextTokens int `json:"text_tokens,omitempty"` // Tokens for text input - AudioTokens int `json:"audio_tokens,omitempty"` // Tokens for audio input - ImageTokens int `json:"image_tokens,omitempty"` // Tokens for image input + TextTokens int `json:"text_tokens,omitempty"` // Tokens for text input + AudioTokens int `json:"audio_tokens,omitempty"` // Tokens for audio input + ImageTokens int `json:"image_tokens,omitempty"` // Tokens for image input + ToolUseTokens int `json:"tool_use_tokens,omitempty"` // Tokens for tool use input (set by gemini) + + ModalityTokenCount []ModalityTokenCount `json:"modality_token_count,omitempty"` // Detailed breakdown of input tokens by modality (set by gemini) // For Providers which follow OpenAI's spec, CachedTokens means the number of input tokens read from the cache+input tokens used to create the cache entry. (because they do not differentiate between cache creation and cache read tokens) // For Providers which do not follow OpenAI's spec, CachedTokens means only the number of input tokens read from the cache. CachedTokens int `json:"cached_tokens,omitempty"` } +type ModalityTokenCount struct { + Modality string `json:"modality"` + TokenCount int `json:"token_count"` +} + type ResponsesResponseOutputTokens struct { TextTokens int `json:"text_tokens,omitempty"` AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` @@ -280,6 +289,9 @@ type ResponsesResponseOutputTokens struct { RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` CitationTokens *int `json:"citation_tokens,omitempty"` NumSearchQueries *int `json:"num_search_queries,omitempty"` + ToolUseTokens int `json:"tool_use_tokens,omitempty"` // Tokens for tool use input (set by gemini) + + ModalityTokenCount []ModalityTokenCount `json:"modality_token_count,omitempty"` // Detailed breakdown of input tokens by modality (set by gemini) // This means the number of input tokens used to create the cache entry. (cache creation tokens) CachedTokens int `json:"cached_tokens,omitempty"` // Not in OpenAI's schemas, but sent by a few providers (Anthropic, Bedrock are some of them) @@ -795,9 +807,11 @@ type ResponsesImageGenerationCall struct { // ResponsesCodeInterpreterToolCall represents a code interpreter tool call type ResponsesCodeInterpreterToolCall struct { - Code *string `json:"code"` // The code to run, or null if not available - ContainerID string `json:"container_id"` // The ID of the container used to run the code - Outputs []ResponsesCodeInterpreterOutput `json:"outputs"` // The outputs generated by the code interpreter, can be null + Code *string `json:"code"` // The code to run, or null if not available + ContainerID string `json:"container_id"` // The ID of the container used to run the code + Outputs []ResponsesCodeInterpreterOutput `json:"outputs"` // The outputs generated by the code interpreter, can be null + ExpiresAt *string `json:"expires_at"` // ISO 8601 timestamp when the container expires (sent by Anthropic) + Language *string `json:"language,omitempty"` // The language of the code (sent by Gemini) } // ResponsesCodeInterpreterOutput represents a code interpreter output @@ -867,8 +881,9 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { // ResponsesCodeInterpreterOutputLogs represents the logs output from the code interpreter type ResponsesCodeInterpreterOutputLogs struct { - Logs string `json:"logs"` - Type string `json:"type"` // always "logs" + Type string `json:"type"` // always "logs" + Logs string `json:"logs"` + ReturnCode *int `json:"return_code,omitempty"` // sent by Anthropic } // ResponsesCodeInterpreterOutputImage represents the image output from the code interpreter @@ -1346,7 +1361,44 @@ type ResponsesToolMCPAllowedToolsApprovalFilter struct { // ResponsesToolCodeInterpreter represents a tool code interpreter type ResponsesToolCodeInterpreter struct { - Container interface{} `json:"container"` // Container ID or object with file IDs + Container ResponsesToolCodeInterpreterContainer `json:"container"` // Either a string or a ResponsesToolCodeInterpreterContainerStruct object +} + +type ResponsesToolCodeInterpreterContainer struct { + ResponsesToolCodeInterpreterContainerString *string + ResponsesToolCodeInterpreterContainerStruct *ResponsesToolCodeInterpreterContainerStruct +} + +func (c *ResponsesToolCodeInterpreterContainer) MarshalJSON() ([]byte, error) { + if c.ResponsesToolCodeInterpreterContainerString != nil { + return Marshal(*c.ResponsesToolCodeInterpreterContainerString) + } + if c.ResponsesToolCodeInterpreterContainerStruct != nil { + return Marshal(*c.ResponsesToolCodeInterpreterContainerStruct) + } + return nil, fmt.Errorf("container field is neither a string nor a ResponsesToolCodeInterpreterContainerStruct object") +} + +func (c *ResponsesToolCodeInterpreterContainer) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var containerString string + if err := Unmarshal(data, &containerString); err == nil { + c.ResponsesToolCodeInterpreterContainerString = &containerString + return nil + } + // Try to unmarshal as a ResponsesToolCodeInterpreterContainerStruct object + var containerStruct ResponsesToolCodeInterpreterContainerStruct + if err := Unmarshal(data, &containerStruct); err == nil { + c.ResponsesToolCodeInterpreterContainerStruct = &containerStruct + return nil + } + return fmt.Errorf("container field is neither a string nor a ResponsesToolCodeInterpreterContainerStruct object") +} + +type ResponsesToolCodeInterpreterContainerStruct struct { + Type string `json:"type"` // always "auto" + MemoryLimit *string `json:"memory_limit,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` } // ResponsesToolImageGeneration represents a tool image generation diff --git a/tests/integrations/python/config.yml b/tests/integrations/python/config.yml index 1df0ca0ad..4f6f3f5d9 100644 --- a/tests/integrations/python/config.yml +++ b/tests/integrations/python/config.yml @@ -184,6 +184,7 @@ provider_scenarios: end2end_tool_calling: true automatic_function_calling: true "web_search": true + code_execution: true image_url: true image_base64: true file_input: true @@ -247,6 +248,7 @@ provider_scenarios: end2end_tool_calling: true automatic_function_calling: true web_search: true + code_execution: true image_url: true image_base64: true file_input: true @@ -289,6 +291,7 @@ provider_scenarios: multiple_tool_calls: true end2end_tool_calling: true automatic_function_calling: true + code_execution: true image_url: false # Gemini requires base64 or file upload image_base64: true file_input: true @@ -457,6 +460,7 @@ scenario_capabilities: end2end_tool_calling: "tools" automatic_function_calling: "tools" web_search: "chat" + code_execution: "tools" image_url: "vision" image_base64: "vision" file_input: "file" diff --git a/tests/integrations/python/tests/test_anthropic.py b/tests/integrations/python/tests/test_anthropic.py index 903cf10d6..dd780fc56 100644 --- a/tests/integrations/python/tests/test_anthropic.py +++ b/tests/integrations/python/tests/test_anthropic.py @@ -2310,11 +2310,134 @@ def validate_cache_read(usage: Any, operation: str) -> int: print(f"{operation} usage - input_tokens: {usage.input_tokens}, " f"cache_creation_input_tokens: {getattr(usage, 'cache_creation_input_tokens', 0)}, " f"cache_read_input_tokens: {getattr(usage, 'cache_read_input_tokens', 0)}") - + assert hasattr(usage, 'cache_read_input_tokens'), \ f"{operation} should have cache_read_input_tokens" cache_read_tokens = getattr(usage, 'cache_read_input_tokens', 0) assert cache_read_tokens > 0, \ f"{operation} should read from cache (got {cache_read_tokens} tokens)" - + return cache_read_tokens + + +class TestAnthropicCodeExecution: + """Tests for code execution tool with Anthropic SDK""" + + @pytest.fixture + def anthropic_client(self): + """Create Anthropic client for code execution tests""" + api_key = get_api_key("ANTHROPIC_API_KEY") + skip_if_no_api_key(api_key, "Anthropic") + + return Anthropic( + api_key=api_key, + base_url=os.getenv("BIFROST_BASE_URL", "http://localhost:8787") + "/v1" + ) + + def test_code_execution_math(self, anthropic_client): + """Test code execution with mathematical computation""" + response = anthropic_client.messages.create( + model="claude-sonnet-4-5", + max_tokens=4096, + messages=[ + { + "role": "user", + "content": "Calculate the sum of all prime numbers between 1 and 50 using Python code." + } + ], + tools=[ + { + "type": "code_execution_20250825", + "name": "code_execution" + } + ] + ) + + # Validate response + assert response is not None, "Response should not be None" + assert hasattr(response, "content"), "Response should have content" + assert len(response.content) > 0, "Response should have content blocks" + + # Check for text response + text_blocks = [block for block in response.content if hasattr(block, "type") and block.type == "text"] + assert len(text_blocks) > 0, "Response should have text blocks" + + # The result should mention primes or the sum (328) + response_text = " ".join([block.text for block in text_blocks]) + assert any(keyword in response_text.lower() for keyword in ["328", "prime", "sum"]), \ + f"Response should contain calculation result. Got: {response_text}" + + print(f"✓ Anthropic code execution (math) test passed!") + print(f" Response: {response_text[:200]}...") + + def test_code_execution_data_analysis(self, anthropic_client): + """Test code execution with statistical data analysis""" + response = anthropic_client.messages.create( + model="claude-sonnet-4-5", + max_tokens=4096, + messages=[ + { + "role": "user", + "content": "Calculate the mean and standard deviation of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] using Python." + } + ], + tools=[ + { + "type": "code_execution_20250825", + "name": "code_execution" + } + ] + ) + + # Validate response + assert response is not None, "Response should not be None" + assert hasattr(response, "content"), "Response should have content" + assert len(response.content) > 0, "Response should have content blocks" + + # Check for text response + text_blocks = [block for block in response.content if hasattr(block, "type") and block.type == "text"] + assert len(text_blocks) > 0, "Response should have text blocks" + + # The result should mention mean (5.5) or standard deviation + response_text = " ".join([block.text for block in text_blocks]) + assert any(keyword in response_text.lower() for keyword in ["mean", "5.5", "standard deviation", "std"]), \ + f"Response should contain statistical results. Got: {response_text}" + + print(f"✓ Anthropic code execution (data analysis) test passed!") + print(f" Response: {response_text[:200]}...") + + def test_code_execution_equation_solving(self, anthropic_client): + """Test code execution with equation solving""" + response = anthropic_client.messages.create( + model="claude-sonnet-4-5", + max_tokens=4096, + messages=[ + { + "role": "user", + "content": "Solve the equation 3x + 11 = 14 for x using Python code." + } + ], + tools=[ + { + "type": "code_execution_20250825", + "name": "code_execution" + } + ] + ) + + # Validate response + assert response is not None, "Response should not be None" + assert hasattr(response, "content"), "Response should have content" + assert len(response.content) > 0, "Response should have content blocks" + + # Check for text response + text_blocks = [block for block in response.content if hasattr(block, "type") and block.type == "text"] + assert len(text_blocks) > 0, "Response should have text blocks" + + # The solution should be x = 1 + response_text = " ".join([block.text for block in text_blocks]) + assert any(keyword in response_text.lower() for keyword in ["x = 1", "x=1", "1.0", "solution"]), \ + f"Response should contain equation solution. Got: {response_text}" + + print(f"✓ Anthropic code execution (equation solving) test passed!") + print(f" Response: {response_text[:200]}...") diff --git a/tests/integrations/python/tests/test_google.py b/tests/integrations/python/tests/test_google.py index d17a45318..2944ba7b2 100644 --- a/tests/integrations/python/tests/test_google.py +++ b/tests/integrations/python/tests/test_google.py @@ -2695,3 +2695,141 @@ def extract_google_function_calls(response: Any) -> List[Dict[str, Any]]: continue return function_calls + + +class TestGoogleCodeExecution: + """Tests for code execution tool with Google Gemini SDK""" + + @pytest.fixture + def google_genai_client(self): + """Create Google GenAI client for code execution tests""" + from google import genai + from google.genai import types + + api_key = get_api_key("GOOGLE_API_KEY") + skip_if_no_api_key(api_key, "Google") + + # Configure client with Bifrost base URL + client = genai.Client( + api_key=api_key, + http_options={ + "api_version": "v1alpha" + } + ) + + # Store base URL for the client + base_url = os.getenv("BIFROST_BASE_URL", "http://localhost:8787") + client._base_url = base_url + + return client, types + + def test_code_execution_math(self, google_genai_client): + """Test code execution with mathematical computation""" + client, types = google_genai_client + + response = client.models.generate_content( + model="gemini-2.0-flash", + contents="Calculate the sum of all prime numbers between 1 and 50 using Python code. Show your calculation.", + config=types.GenerateContentConfig( + tools=[types.Tool(code_execution={})] + ) + ) + + # Validate response + assert response is not None, "Response should not be None" + assert hasattr(response, "candidates"), "Response should have candidates" + assert len(response.candidates) > 0, "Response should have at least one candidate" + + candidate = response.candidates[0] + assert hasattr(candidate, "content"), "Candidate should have content" + assert hasattr(candidate.content, "parts"), "Content should have parts" + + # Extract text from all parts + text_parts = [] + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + text_parts.append(part.text) + + response_text = " ".join(text_parts) + assert len(response_text) > 0, "Response should have text content" + + # The result should mention primes or the sum (328) + assert any(keyword in response_text.lower() for keyword in ["328", "prime", "sum"]), \ + f"Response should contain calculation result. Got: {response_text}" + + print(f"✓ Google code execution (math) test passed!") + print(f" Response: {response_text[:200]}...") + + def test_code_execution_data_analysis(self, google_genai_client): + """Test code execution with statistical data analysis""" + client, types = google_genai_client + + response = client.models.generate_content( + model="gemini-2.0-flash", + contents="Calculate the mean and standard deviation of these numbers: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. Use Python to compute.", + config=types.GenerateContentConfig( + tools=[types.Tool(code_execution={})] + ) + ) + + # Validate response + assert response is not None, "Response should not be None" + assert hasattr(response, "candidates"), "Response should have candidates" + assert len(response.candidates) > 0, "Response should have at least one candidate" + + candidate = response.candidates[0] + assert hasattr(candidate, "content"), "Candidate should have content" + assert hasattr(candidate.content, "parts"), "Content should have parts" + + # Extract text from all parts + text_parts = [] + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + text_parts.append(part.text) + + response_text = " ".join(text_parts) + assert len(response_text) > 0, "Response should have text content" + + # The result should mention mean (5.5) or standard deviation + assert any(keyword in response_text.lower() for keyword in ["mean", "5.5", "average", "standard deviation", "std"]), \ + f"Response should contain statistical results. Got: {response_text}" + + print(f"✓ Google code execution (data analysis) test passed!") + print(f" Response: {response_text[:200]}...") + + def test_code_execution_equation_solving(self, google_genai_client): + """Test code execution with equation solving""" + client, types = google_genai_client + + response = client.models.generate_content( + model="gemini-2.0-flash", + contents="Solve the equation 3x + 11 = 14 for x using Python code.", + config=types.GenerateContentConfig( + tools=[types.Tool(code_execution={})] + ) + ) + + # Validate response + assert response is not None, "Response should not be None" + assert hasattr(response, "candidates"), "Response should have candidates" + assert len(response.candidates) > 0, "Response should have at least one candidate" + + candidate = response.candidates[0] + assert hasattr(candidate, "content"), "Candidate should have content" + assert hasattr(candidate.content, "parts"), "Content should have parts" + + # Extract text from all parts + text_parts = [] + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + text_parts.append(part.text) + + response_text = " ".join(text_parts) + assert len(response_text) > 0, "Response should have text content" + + # The solution should be x = 1 + assert any(keyword in response_text.lower() for keyword in ["x = 1", "x=1", "1.0", "1", "solution"]), \ + f"Response should contain equation solution. Got: {response_text}" + + print(f"✓ Google code execution (equation solving) test passed!") + print(f" Response: {response_text[:200]}...") diff --git a/tests/integrations/python/tests/test_openai.py b/tests/integrations/python/tests/test_openai.py index cf34d8abb..b9dd4251c 100644 --- a/tests/integrations/python/tests/test_openai.py +++ b/tests/integrations/python/tests/test_openai.py @@ -3042,3 +3042,116 @@ def test_53_web_search_streaming(self, provider, model, vk_enabled): print(f"✓ Web search (streaming) test passed!") + @pytest.mark.parametrize( + "provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("code_execution") + ) + def test_64_code_execution_math(self, test_config, provider, model, vk_enabled): + """Test Case 64: Code Execution - Mathematical Computation""" + client = get_provider_openai_client(provider, vk_enabled=vk_enabled) + + # Test solving a mathematical equation using code execution + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": "Calculate the sum of all prime numbers between 1 and 50 using Python. Show your work."} + ], + tools=[ + { + "type": "code_interpreter" + } + ], + extra_body={"provider": provider} if not vk_enabled else None, + extra_query={"vk": "true"} if vk_enabled else None + ) + + # Validate response + assert_valid_chat_response(response) + assert response.choices, "Response should have choices" + assert len(response.choices) > 0, "Should have at least one choice" + + message = response.choices[0].message + assert message.content, "Message should have content" + + # The response should contain the result (sum of primes 1-50 = 328) + content_text = message.content.lower() + assert any(keyword in content_text for keyword in ["328", "prime", "sum"]), \ + f"Response should contain calculation result. Got: {message.content}" + + print(f"✓ Code execution (math) test passed!") + print(f" Response: {message.content[:200]}...") + + @pytest.mark.parametrize( + "provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("code_execution") + ) + def test_65_code_execution_data_analysis(self, test_config, provider, model, vk_enabled): + """Test Case 65: Code Execution - Data Analysis""" + client = get_provider_openai_client(provider, vk_enabled=vk_enabled) + + # Test statistical analysis using code execution + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": "Calculate the mean and standard deviation of the following numbers: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. Use Python to compute these statistics."} + ], + tools=[ + { + "type": "code_interpreter" + } + ], + extra_body={"provider": provider} if not vk_enabled else None, + extra_query={"vk": "true"} if vk_enabled else None + ) + + # Validate response + assert_valid_chat_response(response) + assert response.choices, "Response should have choices" + + message = response.choices[0].message + assert message.content, "Message should have content" + + # The response should contain statistical calculations + # Mean should be 5.5, std dev ~2.87 + content_text = message.content.lower() + assert any(keyword in content_text for keyword in ["mean", "average", "5.5", "standard deviation", "std"]), \ + f"Response should contain statistical results. Got: {message.content}" + + print(f"✓ Code execution (data analysis) test passed!") + print(f" Response: {message.content[:200]}...") + + @pytest.mark.parametrize( + "provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("code_execution") + ) + def test_66_code_execution_equation_solving(self, test_config, provider, model, vk_enabled): + """Test Case 66: Code Execution - Equation Solving""" + client = get_provider_openai_client(provider, vk_enabled=vk_enabled) + + # Test equation solving using code execution + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": "Solve the equation 3x + 11 = 14 for x using Python."} + ], + tools=[ + { + "type": "code_interpreter" + } + ], + extra_body={"provider": provider} if not vk_enabled else None, + extra_query={"vk": "true"} if vk_enabled else None + ) + + # Validate response + assert_valid_chat_response(response) + assert response.choices, "Response should have choices" + + message = response.choices[0].message + assert message.content, "Message should have content" + + # The solution should be x = 1 + content_text = message.content.lower() + assert any(keyword in content_text for keyword in ["x = 1", "x=1", "1.0", "solution"]), \ + f"Response should contain equation solution. Got: {message.content}" + + print(f"✓ Code execution (equation solving) test passed!") + print(f" Response: {message.content[:200]}...") + diff --git a/transports/changelog.md b/transports/changelog.md index 5301cc548..4b63a0795 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -3,4 +3,5 @@ - fix: add support for AdditionalProperties structures (both boolean and object types) - fix: improve thought signature handling in gemini for function calls - fix: enhance citations structure to support multiple citation types -- fix: anthropic streaming events through integration \ No newline at end of file +- fix: anthropic streaming events through integration +- feat: added support for code execution tool for openai, anthropic and gemini \ No newline at end of file