diff --git a/core/providers/gemini/chat.go b/core/providers/gemini/chat.go index 948b65f9e..81cac9249 100644 --- a/core/providers/gemini/chat.go +++ b/core/providers/gemini/chat.go @@ -4,285 +4,12 @@ import ( "encoding/base64" "encoding/json" "fmt" - "strings" - "time" "github.com/maximhq/bifrost/core/schemas" ) -func (request *GeminiGenerationRequest) ToBifrostChatRequest() *schemas.BifrostChatRequest { - provider, model := schemas.ParseModelString(request.Model, schemas.Gemini) - - if provider == schemas.Vertex && !request.IsEmbedding { - // Add google/ prefix if not already present and model is not a custom fine-tuned model - if !schemas.IsAllDigitsASCII(model) && !strings.HasPrefix(model, "google/") { - model = "google/" + model - } - } - - // Handle chat completion requests - bifrostReq := &schemas.BifrostChatRequest{ - Provider: provider, - Model: model, - Input: []schemas.ChatMessage{}, - Fallbacks: schemas.ParseFallbacks(request.Fallbacks), - } - - messages := []schemas.ChatMessage{} - // Track all tool calls from previous messages for function response correlation - previousToolCalls := []schemas.ChatAssistantMessageToolCall{} - - allGenAiMessages := []Content{} - if request.SystemInstruction != nil { - allGenAiMessages = append(allGenAiMessages, *request.SystemInstruction) - } - allGenAiMessages = append(allGenAiMessages, request.Contents...) - - for _, content := range allGenAiMessages { - if len(content.Parts) == 0 { - continue - } - - // Handle multiple parts - collect all content and tool calls - var toolCalls []schemas.ChatAssistantMessageToolCall - var contentBlocks []schemas.ChatContentBlock - var thoughtStr string // Track thought content for assistant/model - - for _, part := range content.Parts { - switch { - case part.Text != "": - // Handle thought content specially for assistant messages - if part.Thought && - (content.Role == string(schemas.ChatMessageRoleAssistant) || content.Role == string(RoleModel)) { - thoughtStr = thoughtStr + part.Text + "\n" - } else { - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeText, - Text: &part.Text, - }) - } - - case part.FunctionCall != nil: - // Only add function calls for assistant messages - if content.Role == string(schemas.ChatMessageRoleAssistant) || content.Role == string(RoleModel) { - jsonArgs, err := json.Marshal(part.FunctionCall.Args) - if err != nil { - jsonArgs = []byte(fmt.Sprintf("%v", part.FunctionCall.Args)) - } - name := part.FunctionCall.Name // create local copy - // Gemini primarily works with function names for correlation - // Use ID if provided, otherwise fallback to name for stable correlation - callID := name - if strings.TrimSpace(part.FunctionCall.ID) != "" { - callID = part.FunctionCall.ID - } - toolCall := schemas.ChatAssistantMessageToolCall{ - Index: uint16(len(toolCalls)), - ID: schemas.Ptr(callID), - Type: schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)), - Function: schemas.ChatAssistantMessageToolCallFunction{ - Name: &name, - Arguments: string(jsonArgs), - }, - } - - // Preserve thought signature if present (required for Gemini 3 Pro) - if len(part.ThoughtSignature) > 0 { - toolCall.ExtraContent = map[string]interface{}{ - "google": map[string]interface{}{ - "thought_signature": string(part.ThoughtSignature), - }, - } - } - - toolCalls = append(toolCalls, toolCall) - } - - case part.FunctionResponse != nil: - // Create a separate tool response message - responseContent, err := json.Marshal(part.FunctionResponse.Response) - if err != nil { - responseContent = []byte(fmt.Sprintf("%v", part.FunctionResponse.Response)) - } - - // Correlate with the function call: prefer ID if available, otherwise use name - callID := part.FunctionResponse.Name - if strings.TrimSpace(part.FunctionResponse.ID) != "" { - callID = part.FunctionResponse.ID - } else { - // Fallback: search through all previous tool calls to find matching one by name - for _, tc := range previousToolCalls { - if tc.Function.Name != nil && *tc.Function.Name == part.FunctionResponse.Name && - tc.ID != nil && *tc.ID != "" { - callID = *tc.ID - break - } - } - } - - toolResponseMsg := schemas.ChatMessage{ - Role: schemas.ChatMessageRoleTool, - Content: &schemas.ChatMessageContent{ - ContentStr: schemas.Ptr(string(responseContent)), - }, - ChatToolMessage: &schemas.ChatToolMessage{ - ToolCallID: &callID, - }, - } - - messages = append(messages, toolResponseMsg) - - case part.InlineData != nil: - // Handle inline images/media - only append if it's actually an image - if isImageMimeType(part.InlineData.MIMEType) { - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeImage, - ImageURLStruct: &schemas.ChatInputImage{ - URL: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data)), - }, - }) - } - - case part.FileData != nil: - // Handle file data - only append if it's actually an image - if isImageMimeType(part.FileData.MIMEType) { - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeImage, - ImageURLStruct: &schemas.ChatInputImage{ - URL: part.FileData.FileURI, - }, - }) - } - - case part.ExecutableCode != nil: - // Handle executable code as text content - codeText := fmt.Sprintf("```%s\n%s\n```", part.ExecutableCode.Language, part.ExecutableCode.Code) - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeText, - Text: &codeText, - }) - - case part.CodeExecutionResult != nil: - // Handle code execution results as text content - resultText := fmt.Sprintf("Code execution result (%s):\n%s", part.CodeExecutionResult.Outcome, part.CodeExecutionResult.Output) - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeText, - Text: &resultText, - }) - } - } - - // Only create message if there's actual content, tool calls, or thought content - if len(contentBlocks) > 0 || len(toolCalls) > 0 || thoughtStr != "" { - // Create main message with content blocks - bifrostMsg := schemas.ChatMessage{ - Role: func(r string) schemas.ChatMessageRole { - if r == string(RoleModel) { // GenAI's internal alias - return schemas.ChatMessageRoleAssistant - } - return schemas.ChatMessageRole(r) - }(content.Role), - } - - // Set content only if there are content blocks - if len(contentBlocks) > 0 { - bifrostMsg.Content = &schemas.ChatMessageContent{ - ContentBlocks: contentBlocks, - } - } - - // Set assistant-specific fields for assistant/model messages - if content.Role == string(schemas.ChatMessageRoleAssistant) || content.Role == string(RoleModel) { - if len(toolCalls) > 0 || thoughtStr != "" { - bifrostMsg.ChatAssistantMessage = &schemas.ChatAssistantMessage{} - if len(toolCalls) > 0 { - bifrostMsg.ChatAssistantMessage.ToolCalls = toolCalls - // Track these tool calls for future function response correlation - previousToolCalls = append(previousToolCalls, toolCalls...) - } - } - } - - messages = append(messages, bifrostMsg) - } - } - - bifrostReq.Input = messages - - // Convert generation config to parameters - if params := request.convertGenerationConfigToChatParameters(); params != nil { - bifrostReq.Params = params - } - - // Convert safety settings - if len(request.SafetySettings) > 0 { - ensureExtraParams(bifrostReq) - bifrostReq.Params.ExtraParams["safety_settings"] = request.SafetySettings - } - - // Convert additional request fields - if request.CachedContent != "" { - ensureExtraParams(bifrostReq) - bifrostReq.Params.ExtraParams["cached_content"] = request.CachedContent - } - - // Convert labels - if len(request.Labels) > 0 { - ensureExtraParams(bifrostReq) - bifrostReq.Params.ExtraParams["labels"] = request.Labels - } - - // Convert tools and tool config - if len(request.Tools) > 0 { - ensureExtraParams(bifrostReq) - - tools := make([]schemas.ChatTool, 0, len(request.Tools)) - for _, tool := range request.Tools { - if len(tool.FunctionDeclarations) > 0 { - for _, fn := range tool.FunctionDeclarations { - bifrostTool := schemas.ChatTool{ - Type: schemas.ChatToolTypeFunction, - Function: &schemas.ChatToolFunction{ - Name: fn.Name, - Description: schemas.Ptr(fn.Description), - }, - } - // Convert parameters schema if present - if fn.Parameters != nil { - params := request.convertSchemaToFunctionParameters(fn.Parameters) - bifrostTool.Function.Parameters = ¶ms - } - tools = append(tools, bifrostTool) - } - } - // Handle other tool types (Retrieval, GoogleSearch, etc.) as ExtraParams - if tool.Retrieval != nil { - bifrostReq.Params.ExtraParams["retrieval"] = tool.Retrieval - } - if tool.GoogleSearch != nil { - bifrostReq.Params.ExtraParams["google_search"] = tool.GoogleSearch - } - if tool.CodeExecution != nil { - bifrostReq.Params.ExtraParams["code_execution"] = tool.CodeExecution - } - } - - if len(tools) > 0 { - bifrostReq.Params.Tools = tools - } - } - - // Convert tool config - if request.ToolConfig.FunctionCallingConfig != nil || request.ToolConfig.RetrievalConfig != nil { - ensureExtraParams(bifrostReq) - bifrostReq.Params.ExtraParams["tool_config"] = request.ToolConfig - } - - return bifrostReq -} - // ToGeminiChatCompletionRequest converts a BifrostChatRequest to Gemini's generation request format for chat completion -func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest, responseModalities []string) *GeminiGenerationRequest { +func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *GeminiGenerationRequest { if bifrostReq == nil { return nil } @@ -294,7 +21,7 @@ func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest, respo // Convert parameters to generation config if bifrostReq.Params != nil { - geminiReq.GenerationConfig = convertParamsToGenerationConfig(bifrostReq.Params, responseModalities) + geminiReq.GenerationConfig = convertParamsToGenerationConfig(bifrostReq.Params, []string{}) // Handle tool-related parameters if len(bifrostReq.Params.Tools) > 0 { @@ -348,213 +75,265 @@ func (response *GenerateContentResponse) ToBifrostChatResponse() *schemas.Bifros bifrostResp.Created = int(response.CreateTime.Unix()) } - // Extract usage metadata - inputTokens, outputTokens, totalTokens, cachedTokens, reasoningTokens := response.extractUsageMetadata() + // Collect all content and tool calls into a single message + var toolCalls []schemas.ChatAssistantMessageToolCall + var contentBlocks []schemas.ChatContentBlock + var reasoningDetails []schemas.ChatReasoningDetails // Process candidates to extract text content if len(response.Candidates) > 0 { candidate := response.Candidates[0] if candidate.Content != nil && len(candidate.Content.Parts) > 0 { - var textContent string - - // Extract text content from all parts for _, part := range candidate.Content.Parts { if part.Text != "" { - textContent += part.Text + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: &part.Text, + }) } - } - if textContent != "" { - // Create choice from the candidate - choice := schemas.BifrostResponseChoice{ - Index: 0, - ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ - Message: &schemas.ChatMessage{ - Role: schemas.ChatMessageRoleAssistant, - Content: &schemas.ChatMessageContent{ - ContentStr: &textContent, - }, - }, - }, + if part.FunctionCall != nil { + function := schemas.ChatAssistantMessageToolCallFunction{ + Name: &part.FunctionCall.Name, + } + + if part.FunctionCall.Args != nil { + jsonArgs, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + jsonArgs = []byte(fmt.Sprintf("%v", part.FunctionCall.Args)) + } + function.Arguments = string(jsonArgs) + } + + callID := part.FunctionCall.Name + if part.FunctionCall.ID != "" { + callID = part.FunctionCall.ID + } + + toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)), + ID: &callID, + Function: function, + }) } - // Set finish reason if available - if candidate.FinishReason != "" { - finishReason := string(candidate.FinishReason) - choice.FinishReason = &finishReason + if part.FunctionResponse != nil { + // Extract the output from the response + output := extractFunctionResponseOutput(part.FunctionResponse) + + // Add as text content block + if output != "" { + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: &output, + }) + } + } + if part.ThoughtSignature != nil { + thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature) + reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{ + Index: len(reasoningDetails), + Type: schemas.BifrostReasoningDetailsTypeEncrypted, + Signature: &thoughtSig, + }) + } + } + + // Build the choice with message + message := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + } + + if len(contentBlocks) > 0 { + message.Content = &schemas.ChatMessageContent{ + ContentBlocks: contentBlocks, } + } - bifrostResp.Choices = []schemas.BifrostResponseChoice{choice} + if len(toolCalls) > 0 || len(reasoningDetails) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + ReasoningDetails: reasoningDetails, + } } + + bifrostResp.Choices = append(bifrostResp.Choices, schemas.BifrostResponseChoice{ + Index: 0, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: message, + }, + }) } } // Set usage information - bifrostResp.Usage = &schemas.BifrostLLMUsage{ - PromptTokens: inputTokens, - CompletionTokens: outputTokens, - TotalTokens: totalTokens, - PromptTokensDetails: &schemas.ChatPromptTokensDetails{ - CachedTokens: cachedTokens, - }, - CompletionTokensDetails: &schemas.ChatCompletionTokensDetails{ - ReasoningTokens: reasoningTokens, - }, - } + bifrostResp.Usage = convertGeminiUsageMetadataToChatUsage(response.UsageMetadata) return bifrostResp } -// ToGeminiChatResponse converts a BifrostChatResponse to Gemini's GenerateContentResponse -func ToGeminiChatResponse(bifrostResp *schemas.BifrostChatResponse) *GenerateContentResponse { - if bifrostResp == nil { - return nil +// ToBifrostChatCompletionStream converts a Gemini streaming response to a Bifrost Chat Completion Stream response +// Returns the response, error (if any), and a boolean indicating if this is the last chunk +func (response *GenerateContentResponse) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) { + if response == nil { + return nil, nil, false } - genaiResp := &GenerateContentResponse{ - ResponseID: bifrostResp.ID, - ModelVersion: bifrostResp.Model, + // Check if we have candidates with content + if len(response.Candidates) == 0 { + return nil, nil, false + } + + candidate := response.Candidates[0] + + // Determine if this is the last chunk based on finish reason or usage metadata + isLastChunk := candidate.FinishReason != "" && candidate.FinishReason != FinishReasonUnspecified + + // Create the streaming response + streamResponse := &schemas.BifrostChatResponse{ + ID: response.ResponseID, + Model: response.ModelVersion, + Object: "chat.completion.chunk", } - // Set creation time if available - if bifrostResp.Created > 0 { - genaiResp.CreateTime = time.Unix(int64(bifrostResp.Created), 0) + // Set creation timestamp if available + if !response.CreateTime.IsZero() { + streamResponse.Created = int(response.CreateTime.Unix()) } - if len(bifrostResp.Choices) > 0 { - candidates := make([]*Candidate, len(bifrostResp.Choices)) + // Build delta content + delta := &schemas.ChatStreamResponseChoiceDelta{} - for i, choice := range bifrostResp.Choices { - candidate := &Candidate{ - Index: int32(choice.Index), + // Process content parts + if candidate.Content != nil && len(candidate.Content.Parts) > 0 { + // Set role from the first chunk (Gemini uses "model" for assistant) + if candidate.Content.Role != "" { + role := candidate.Content.Role + if role == string(RoleModel) { + role = string(schemas.ChatMessageRoleAssistant) } + delta.Role = &role + } - if choice.FinishReason != nil { - candidate.FinishReason = FinishReason(*choice.FinishReason) - } + var textContent string + var thoughtContent string + var toolCalls []schemas.ChatAssistantMessageToolCall + var reasoningDetails []schemas.ChatReasoningDetails - // Convert message content to Gemini parts - var parts []*Part - var role string + for _, part := range candidate.Content.Parts { + switch { + case part.Text != "" && part.Thought: + // Thought/reasoning content + thoughtContent += part.Text - // Handle streaming responses - if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { - delta := choice.ChatStreamResponseChoice.Delta + case part.Text != "": + // Regular text content + textContent += part.Text - // Set role from delta if available - if delta.Role != nil { - role = *delta.Role - } else { - role = "model" // Default role for streaming responses + case part.FunctionCall != nil: + // Function call + jsonArgs := "" + if part.FunctionCall.Args != nil { + if argsBytes, err := json.Marshal(part.FunctionCall.Args); err == nil { + jsonArgs = string(argsBytes) + } } - // Handle content text - if delta.Content != nil && *delta.Content != "" { - parts = append(parts, &Part{Text: *delta.Content}) + // Use ID if available, otherwise use function name + callID := part.FunctionCall.Name + if part.FunctionCall.ID != "" { + callID = part.FunctionCall.ID } - // Handle tool calls in streaming - if delta.ToolCalls != nil { - for _, toolCall := range delta.ToolCalls { - argsMap := make(map[string]interface{}) - if toolCall.Function.Arguments != "" { - json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) - } - if toolCall.Function.Name != nil { - fc := &FunctionCall{ - Name: *toolCall.Function.Name, - Args: argsMap, - } - if toolCall.ID != nil { - fc.ID = *toolCall.ID - } - parts = append(parts, &Part{FunctionCall: fc}) - } - } + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)), + ID: &callID, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &part.FunctionCall.Name, + Arguments: jsonArgs, + }, } - if len(parts) > 0 { - candidate.Content = &Content{ - Parts: parts, - Role: role, - } - } - } else if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil { - // Handle non-streaming responses - if choice.ChatNonStreamResponseChoice.Message.Content != nil { - if choice.ChatNonStreamResponseChoice.Message.Content.ContentStr != nil && *choice.ChatNonStreamResponseChoice.Message.Content.ContentStr != "" { - parts = append(parts, &Part{Text: *choice.ChatNonStreamResponseChoice.Message.Content.ContentStr}) - } else if choice.ChatNonStreamResponseChoice.Message.Content.ContentBlocks != nil { - for _, block := range choice.ChatNonStreamResponseChoice.Message.Content.ContentBlocks { - if block.Text != nil { - parts = append(parts, &Part{Text: *block.Text}) - } - } + // Preserve thought signature if present (required for Gemini 3 Pro) + if len(part.ThoughtSignature) > 0 { + toolCall.ExtraContent = map[string]interface{}{ + "google": map[string]interface{}{ + "thought_signature": string(part.ThoughtSignature), + }, } } - // Handle tool calls - if choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil && choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls != nil { - for _, toolCall := range choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls { - argsMap := make(map[string]interface{}) - if toolCall.Function.Arguments != "" { - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap); err != nil { - argsMap = map[string]interface{}{} - } - } - if toolCall.Function.Name != nil { - fc := &FunctionCall{ - Name: *toolCall.Function.Name, - Args: argsMap, - } - if toolCall.ID != nil { - fc.ID = *toolCall.ID - } - - part := &Part{FunctionCall: fc} - - // Preserve thought signature from extra_content (required for Gemini 3 Pro) - if toolCall.ExtraContent != nil { - if googleData, ok := toolCall.ExtraContent["google"].(map[string]interface{}); ok { - if thoughtSig, ok := googleData["thought_signature"].(string); ok { - part.ThoughtSignature = []byte(thoughtSig) - } - } - } + toolCalls = append(toolCalls, toolCall) - parts = append(parts, part) - } + case part.FunctionResponse != nil: + // Extract the output from the response and add to text content + output := extractFunctionResponseOutput(part.FunctionResponse) + if output != "" { + textContent += output } } - if len(parts) > 0 { - candidate.Content = &Content{ - Parts: parts, - Role: string(choice.ChatNonStreamResponseChoice.Message.Role), - } - } + // Handle thought signature separately (not part of the switch since it can co-exist with other types) + if part.ThoughtSignature != nil { + thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature) + reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{ + Index: len(reasoningDetails), + Type: schemas.BifrostReasoningDetailsTypeEncrypted, + Signature: &thoughtSig, + }) } - - candidates[i] = candidate } - genaiResp.Candidates = candidates - } + // Set text content if present + if textContent != "" { + delta.Content = &textContent + } - // Set usage metadata from LLM usage - if bifrostResp.Usage != nil { - genaiResp.UsageMetadata = &GenerateContentResponseUsageMetadata{ - PromptTokenCount: int32(bifrostResp.Usage.PromptTokens), - CandidatesTokenCount: int32(bifrostResp.Usage.CompletionTokens), - TotalTokenCount: int32(bifrostResp.Usage.TotalTokens), + // Set thought content if present + if thoughtContent != "" { + delta.Reasoning = &thoughtContent } - if bifrostResp.Usage.PromptTokensDetails != nil { - genaiResp.UsageMetadata.CachedContentTokenCount = int32(bifrostResp.Usage.PromptTokensDetails.CachedTokens) + + // Set reasoning details if present + if len(reasoningDetails) > 0 { + delta.ReasoningDetails = reasoningDetails } - if bifrostResp.Usage.CompletionTokensDetails != nil { - genaiResp.UsageMetadata.ThoughtsTokenCount = int32(bifrostResp.Usage.CompletionTokensDetails.ReasoningTokens) + + // Set tool calls if present + if len(toolCalls) > 0 { + delta.ToolCalls = toolCalls } } - return genaiResp + // Check if delta has any content - if not and it's not the last chunk, skip it + hasDeltaContent := delta.Role != nil || delta.Content != nil || delta.Reasoning != nil || len(delta.ToolCalls) > 0 || len(delta.ReasoningDetails) > 0 + if !hasDeltaContent && !isLastChunk { + return nil, nil, false + } + + // Build the choice + var finishReason *string + if isLastChunk && candidate.FinishReason != "" { + reason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason) + finishReason = &reason + } + + choice := schemas.BifrostResponseChoice{ + Index: int(candidate.Index), + FinishReason: finishReason, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: delta, + }, + } + + streamResponse.Choices = []schemas.BifrostResponseChoice{choice} + + // Add usage information if this is the last chunk + if isLastChunk && response.UsageMetadata != nil { + streamResponse.Usage = convertGeminiUsageMetadataToChatUsage(response.UsageMetadata) + } + + return streamResponse, nil, isLastChunk } diff --git a/core/providers/gemini/embedding.go b/core/providers/gemini/embedding.go index 160264605..4beea9ccb 100644 --- a/core/providers/gemini/embedding.go +++ b/core/providers/gemini/embedding.go @@ -107,13 +107,6 @@ func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest() *schemas.Bif provider, model := schemas.ParseModelString(request.Model, schemas.Gemini) - if provider == schemas.Vertex && !request.IsEmbedding { - // Add google/ prefix if not already present and model is not a custom fine-tuned model - if !schemas.IsAllDigitsASCII(model) && !strings.HasPrefix(model, "google/") { - model = "google/" + model - } - } - // Create the embedding request bifrostReq := &schemas.BifrostEmbeddingRequest{ Provider: provider, diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 12b9a2318..f18ab271a 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -225,145 +225,612 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas. jsonData, err := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return openai.ToOpenAIChatRequest(request), nil }, + func() (any, error) { return ToGeminiChatCompletionRequest(request), nil }, provider.GetProviderKey()) if err != nil { return nil, err } - // Create request + geminiResponse, rawResponse, latency, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") + if bifrostErr != nil { + return nil, bifrostErr + } + + bifrostResponse := geminiResponse.ToBifrostChatResponse() + + bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest + bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() + + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + return bifrostResponse, nil +} + +// ChatCompletionStream performs a streaming chat completion request to the Gemini API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if chat completion stream is allowed for this provider + if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { + return nil, err + } + + jsonData, err := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + reqBody := ToGeminiChatCompletionRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("chat completion request is not provided or could not be converted to Gemini format") + } + return reqBody, nil + }, + provider.GetProviderKey()) + if err != nil { + return nil, err + } + + // Prepare Gemini headers + headers := map[string]string{ + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + if key.Value != "" { + headers["x-goog-api-key"] = key.Value + } + + // Use shared Gemini streaming logic + return HandleGeminiChatCompletionStream( + ctx, + provider.client, + provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/models/"+request.Model+":streamGenerateContent?alt=sse"), + jsonData, + headers, + provider.networkConfig.ExtraHeaders, + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + request.Model, + postHookRunner, + nil, + provider.logger, + ) +} + +// HandleGeminiChatCompletionStream handles streaming for Gemini-compatible APIs. +func HandleGeminiChatCompletionStream( + ctx context.Context, + client *fasthttp.Client, + url string, + jsonBody []byte, + headers map[string]string, + extraHeaders map[string]string, + sendBackRawResponse bool, + providerName schemas.ModelProvider, + model string, + postHookRunner schemas.PostHookRunner, + postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, + logger schemas.Logger, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() + resp.StreamBody = true defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - // Set any extra headers from network config - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/openai/chat/completions")) req.Header.SetMethod(http.MethodPost) + req.SetRequestURI(url) req.Header.SetContentType("application/json") - if key.Value != "" { - req.Header.Set("Authorization", "Bearer "+key.Value) + providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil) + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) } - req.SetBody(jsonData) + req.SetBody(jsonBody) - // Make request - latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr + // Make the request + doErr := client.Do(req, resp) + if doErr != nil { + defer providerUtils.ReleaseStreamingResponse(resp) + if errors.Is(doErr, context.Canceled) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Type: schemas.Ptr(schemas.RequestCancelled), + Message: schemas.ErrRequestCancelled, + Error: doErr, + }, + } + } + if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, doErr, providerName) + } + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName) } - // Handle error response + // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp) + defer providerUtils.ReleaseStreamingResponse(resp) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode()), fmt.Errorf("%s", string(resp.Body())), resp.StatusCode(), providerName, nil, nil) } - body, decodeErr := providerUtils.CheckAndDecodeBody(resp) - if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) - } + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) - response := &schemas.BifrostChatResponse{} + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer providerUtils.ReleaseStreamingResponse(resp) - rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) - if bifrostErr != nil { - return nil, bifrostErr - } + scanner := bufio.NewScanner(resp.BodyStream()) + buf := make([]byte, 0, 1024*1024) + scanner.Buffer(buf, 10*1024*1024) - for _, choice := range response.Choices { - if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil && choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil && choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls != nil { - for i, toolCall := range choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls { - if (toolCall.ID == nil || *toolCall.ID == "") && toolCall.Function.Name != nil && *toolCall.Function.Name != "" { - id := "" - if toolCall.Function.Name != nil { - id = *toolCall.Function.Name + chunkIndex := 0 + startTime := time.Now() + lastChunkTime := startTime + + var responseID string + var modelName string + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE data + if !strings.HasPrefix(line, "data: ") { + continue + } + + eventData := strings.TrimPrefix(line, "data: ") + + // Skip empty data + if strings.TrimSpace(eventData) == "" { + continue + } + + // Process chunk using shared function + geminiResponse, err := processGeminiStreamChunk(eventData) + if err != nil { + if strings.Contains(err.Error(), "gemini api error") { + // Handle API error + bifrostErr := &schemas.BifrostError{ + Type: schemas.Ptr("gemini_api_error"), + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + }, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + logger.Warn(fmt.Sprintf("Failed to process chunk: %v", err)) + continue + } + + // Track response ID and model + if geminiResponse.ResponseID != "" && responseID == "" { + responseID = geminiResponse.ResponseID + } + if geminiResponse.ModelVersion != "" && modelName == "" { + modelName = geminiResponse.ModelVersion + } + + // Convert to Bifrost stream response + response, bifrostErr, isLastChunk := geminiResponse.ToBifrostChatCompletionStream() + if bifrostErr != nil { + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + + if response != nil { + response.ID = responseID + if modelName != "" { + response.Model = modelName + } + response.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + } + + if postResponseConverter != nil { + response = postResponseConverter(response) + if response == nil { + logger.Warn("postResponseConverter returned nil; skipping chunk") + continue } - (choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls)[i].ID = &id } + + if sendBackRawResponse { + response.ExtraFields.RawResponse = eventData + } + + lastChunkTime = time.Now() + chunkIndex++ + + if isLastChunk { + response.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + break + } + + // Process response through post-hooks and send to channel + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) } } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) + } + }() + + return responseChan, nil +} + +// Responses performs a chat completion request to Gemini's API. +// It formats the request, sends it to Gemini, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + return nil, err } - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.Latency = latency.Milliseconds() + // Convert to Gemini format using the centralized converter + jsonData, err := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + reqBody := ToGeminiResponsesRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("responses input is not provided or could not be converted to Gemini format") + } + return reqBody, nil + }, + provider.GetProviderKey()) + if err != nil { + return nil, err + } + + // Use struct directly for JSON marshaling + geminiResponse, rawResponse, latency, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() + + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { - response.ExtraFields.RawRequest = rawRequest + providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) } // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { - response.ExtraFields.RawResponse = rawResponse + bifrostResponse.ExtraFields.RawResponse = rawResponse } - return response, nil + return bifrostResponse, nil } -// ChatCompletionStream performs a streaming chat completion request to the Gemini API. -// It supports real-time streaming of responses using Server-Sent Events (SSE). -// Uses Gemini's OpenAI-compatible streaming format. -// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Check if chat completion stream is allowed for this provider - if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { +// ResponsesStream performs a streaming responses request to the Gemini API. +func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + // Check if responses stream is allowed for this provider + if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + return nil, err + } + + jsonData, err := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + reqBody := ToGeminiResponsesRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("responses input is not provided or could not be converted to Gemini format") + } + return reqBody, nil + }, + provider.GetProviderKey()) + if err != nil { return nil, err } - var authHeader map[string]string + // Prepare Gemini headers + headers := map[string]string{ + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } if key.Value != "" { - authHeader = map[string]string{"Authorization": "Bearer " + key.Value} + headers["x-goog-api-key"] = key.Value } - // Use shared OpenAI-compatible streaming logic - return openai.HandleOpenAIChatCompletionStreaming( + return HandleGeminiResponsesStream( ctx, provider.client, - provider.networkConfig.BaseURL+"/openai/chat/completions", - request, - authHeader, + provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/models/"+request.Model+":streamGenerateContent?alt=sse"), + jsonData, + headers, provider.networkConfig.ExtraHeaders, - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + request.Model, postHookRunner, nil, - nil, - nil, provider.logger, ) } -// Responses performs a chat completion request to Anthropic's API. -// It formats the request, sends it to Anthropic, and processes the response. -// Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) - if err != nil { - return nil, err +// HandleGeminiResponsesStream handles streaming for Gemini-compatible APIs. +func HandleGeminiResponsesStream( + ctx context.Context, + client *fasthttp.Client, + url string, + jsonBody []byte, + headers map[string]string, + extraHeaders map[string]string, + sendBackRawResponse bool, + providerName schemas.ModelProvider, + model string, + postHookRunner schemas.PostHookRunner, + postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, + logger schemas.Logger, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + resp.StreamBody = true + defer fasthttp.ReleaseRequest(req) + + req.Header.SetMethod(http.MethodPost) + req.SetRequestURI(url) + req.Header.SetContentType("application/json") + providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil) + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) } - response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model + req.SetBody(jsonBody) - return response, nil -} + // Make the request + doErr := client.Do(req, resp) + if doErr != nil { + defer providerUtils.ReleaseStreamingResponse(resp) + if errors.Is(doErr, context.Canceled) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Type: schemas.Ptr(schemas.RequestCancelled), + Message: schemas.ErrRequestCancelled, + Error: doErr, + }, + } + } + if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, doErr, providerName) + } + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName) + } -// ResponsesStream performs a streaming responses request to the Gemini API. -func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) - return provider.ChatCompletionStream( - ctx, - postHookRunner, - key, - request.ToChatRequest(), - ) + // Check for HTTP errors + if resp.StatusCode() != fasthttp.StatusOK { + defer providerUtils.ReleaseStreamingResponse(resp) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerName, resp.StatusCode()), fmt.Errorf("%s", string(resp.Body())), resp.StatusCode(), providerName, nil, nil) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer providerUtils.ReleaseStreamingResponse(resp) + + scanner := bufio.NewScanner(resp.BodyStream()) + buf := make([]byte, 0, 1024*1024) + scanner.Buffer(buf, 10*1024*1024) + + chunkIndex := 0 + sequenceNumber := 0 // Track sequence across all events + startTime := time.Now() + lastChunkTime := startTime + + // Initialize stream state for responses lifecycle management + streamState := acquireGeminiResponsesStreamState() + defer releaseGeminiResponsesStreamState(streamState) + + var lastUsageMetadata *GenerateContentResponseUsageMetadata + + for scanner.Scan() { + line := scanner.Text() + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE data + if !strings.HasPrefix(line, "data: ") { + continue + } + + eventData := strings.TrimPrefix(line, "data: ") + + // Skip empty data + if strings.TrimSpace(eventData) == "" { + continue + } + + // Process chunk using shared function + geminiResponse, err := processGeminiStreamChunk(eventData) + if err != nil { + if strings.Contains(err.Error(), "gemini api error") { + // Handle API error + bifrostErr := &schemas.BifrostError{ + Type: schemas.Ptr("gemini_api_error"), + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: providerName, + ModelRequested: model, + }, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + logger.Warn(fmt.Sprintf("Failed to process chunk: %v", err)) + continue + } + + // Track usage metadata from the latest chunk + if geminiResponse.UsageMetadata != nil { + lastUsageMetadata = geminiResponse.UsageMetadata + } + + // Convert to Bifrost responses stream response + responses, bifrostErr := geminiResponse.ToBifrostResponsesStream(sequenceNumber, streamState) + if bifrostErr != nil { + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: providerName, + ModelRequested: model, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + + for i, response := range responses { + if response != nil { + response.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + } + + if postResponseConverter != nil { + response = postResponseConverter(response) + if response == nil { + logger.Warn("postResponseConverter returned nil; skipping chunk") + continue + } + } + + // Only add raw response to the LAST response in the array + if sendBackRawResponse && i == len(responses)-1 { + response.ExtraFields.RawResponse = eventData + } + + chunkIndex++ + sequenceNumber++ // Increment sequence number for each response + + // Check if this is the last chunk + isLastChunk := false + if response.Type == schemas.ResponsesStreamResponseTypeCompleted { + isLastChunk = true + } + + if isLastChunk { + response.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + return + } + + // For multiple responses in one event, only update timing on the last one + if i == len(responses)-1 { + lastChunkTime = time.Now() + } + + // Process response through post-hooks and send to channel + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + } + } + } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) + } else { + // Finalize the stream by closing any open items + finalResponses := FinalizeGeminiResponsesStream(streamState, lastUsageMetadata, sequenceNumber) + for i, finalResponse := range finalResponses { + finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + } + + if postResponseConverter != nil { + finalResponse = postResponseConverter(finalResponse) + if finalResponse == nil { + logger.Warn("postResponseConverter returned nil; skipping final response") + continue + } + } + + chunkIndex++ + sequenceNumber++ + + if sendBackRawResponse { + finalResponse.ExtraFields.RawResponse = "{}" // Final event has no payload + } + + // Set final latency on the last response (completed event) + if i == len(finalResponses)-1 { + finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() + } + + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) + } + } + }() + + return responseChan, nil } // Embedding performs an embedding request to the Gemini API. diff --git a/core/providers/gemini/gemini_test.go b/core/providers/gemini/gemini_test.go index 0f8ac24db..0ecbe2204 100644 --- a/core/providers/gemini/gemini_test.go +++ b/core/providers/gemini/gemini_test.go @@ -32,7 +32,7 @@ func TestGemini(t *testing.T) { SpeechSynthesisFallbacks: []schemas.Fallback{ {Provider: schemas.Gemini, Model: "gemini-2.5-pro-preview-tts"}, }, - ReasoningModel: "gemini-2.5-pro", + ReasoningModel: "gemini-3-pro-preview", Scenarios: testutil.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, @@ -52,7 +52,7 @@ func TestGemini(t *testing.T) { TranscriptionStream: false, SpeechSynthesis: true, SpeechSynthesisStream: true, - Reasoning: false, //TODO: Supported but lost since we map Gemini's responses via chat completions, fix is a native Gemini handler or reasoning support in chat completions + Reasoning: true, ListModels: true, }, } diff --git a/core/providers/gemini/responses.go b/core/providers/gemini/responses.go index 8c2942c11..72031a8cb 100644 --- a/core/providers/gemini/responses.go +++ b/core/providers/gemini/responses.go @@ -5,14 +5,72 @@ import ( "encoding/json" "fmt" "strings" + "sync" + "time" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" ) -func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*GeminiGenerationRequest, error) { +func (request *GeminiGenerationRequest) ToBifrostResponsesRequest() *schemas.BifrostResponsesRequest { + if request == nil { + return nil + } + + provider, model := schemas.ParseModelString(request.Model, schemas.Gemini) + + // Create the BifrostResponsesRequest + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: provider, + Model: model, + } + + params := request.convertGenerationConfigToResponsesParameters() + + // Convert Contents to Input messages + if len(request.Contents) > 0 { + bifrostReq.Input = convertGeminiContentsToResponsesMessages(request.Contents) + } + + if request.SystemInstruction != nil { + var systemInstructionText string + if len(request.SystemInstruction.Parts) > 0 { + for _, part := range request.SystemInstruction.Parts { + if part.Text != "" { + systemInstructionText += part.Text + } + } + } + if systemInstructionText != "" { + params.Instructions = &systemInstructionText + } + } + + if len(request.Tools) > 0 { + params.Tools = convertGeminiToolsToResponsesTools(request.Tools) + } + + if request.ToolConfig.FunctionCallingConfig != nil { + params.ToolChoice = convertGeminiToolConfigToToolChoice(request.ToolConfig) + } + + if request.SafetySettings != nil { + params.ExtraParams["safety_settings"] = request.SafetySettings + } + + if request.CachedContent != "" { + params.ExtraParams["cached_content"] = request.CachedContent + } + + bifrostReq.Params = params + + return bifrostReq + +} + +func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *GeminiGenerationRequest { if bifrostReq == nil { - return nil, nil + return nil } // Create the base Gemini generation request @@ -22,7 +80,7 @@ func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*Gem // Convert parameters to generation config if bifrostReq.Params != nil { - geminiReq.GenerationConfig = convertParamsToGenerationConfigResponses(bifrostReq.Params) + geminiReq.GenerationConfig = geminiReq.convertParamsToGenerationConfigResponses(bifrostReq.Params) // Handle tool-related parameters if len(bifrostReq.Params.Tools) > 0 { @@ -39,7 +97,7 @@ func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*Gem if bifrostReq.Input != nil { contents, systemInstruction, err := convertResponsesMessagesToGeminiContents(bifrostReq.Input) if err != nil { - return nil, fmt.Errorf("failed to convert messages: %w", err) + return nil } geminiReq.Contents = contents @@ -48,7 +106,18 @@ func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*Gem } } - return geminiReq, nil + if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { + if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { + if settings, ok := safetySettings.([]SafetySetting); ok { + geminiReq.SafetySettings = settings + } + } + if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok { + geminiReq.CachedContent = cachedContent + } + } + + return geminiReq } // ToResponsesBifrostResponsesResponse converts a Gemini GenerateContentResponse to a BifrostResponsesResponse @@ -57,25 +126,13 @@ func (response *GenerateContentResponse) ToResponsesBifrostResponsesResponse() * return nil } - // Parse model string to get provider and model - // Create the BifrostResponse with Responses structure - bifrostResp := &schemas.BifrostResponsesResponse{} + bifrostResp := &schemas.BifrostResponsesResponse{ + Model: response.ModelVersion, + } // Convert usage information - if response.UsageMetadata != nil { - bifrostResp.Usage = &schemas.ResponsesResponseUsage{ - TotalTokens: int(response.UsageMetadata.TotalTokenCount), - InputTokens: int(response.UsageMetadata.PromptTokenCount), - OutputTokens: int(response.UsageMetadata.CandidatesTokenCount), - InputTokensDetails: &schemas.ResponsesResponseInputTokens{}, - } - - // Handle cached tokens if present - if response.UsageMetadata.CachedContentTokenCount > 0 { - bifrostResp.Usage.InputTokensDetails.CachedTokens = int(response.UsageMetadata.CachedContentTokenCount) - } - } + bifrostResp.Usage = convertGeminiUsageMetadataToResponsesUsage(response.UsageMetadata) // Convert candidates to Responses output messages if len(response.Candidates) > 0 { @@ -88,6 +145,1506 @@ func (response *GenerateContentResponse) ToResponsesBifrostResponsesResponse() * return bifrostResp } +func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *GenerateContentResponse { + if bifrostResp == nil { + return nil + } + + geminiResp := &GenerateContentResponse{ + ModelVersion: bifrostResp.Model, + } + + // Set response ID if available + if bifrostResp.ID != nil { + geminiResp.ResponseID = *bifrostResp.ID + } + + // Set creation time + if bifrostResp.CreatedAt > 0 { + geminiResp.CreateTime = time.Unix(int64(bifrostResp.CreatedAt), 0) + } + + // Convert output messages to candidates + if len(bifrostResp.Output) > 0 { + candidates := []*Candidate{} + + // Group messages by their role to create candidates + var currentParts []*Part + var currentRole string + + for _, msg := range bifrostResp.Output { + // Determine the role + role := "model" // default + if msg.Role != nil { + switch *msg.Role { + case schemas.ResponsesInputMessageRoleAssistant: + role = "model" + case schemas.ResponsesInputMessageRoleUser: + role = "user" + default: + role = "model" + } + } + + // If we're starting a new candidate (role changed), save the previous one + if currentRole != "" && currentRole != role && len(currentParts) > 0 { + candidates = append(candidates, &Candidate{ + Index: int32(len(candidates)), + Content: &Content{ + Parts: currentParts, + Role: currentRole, + }, + }) + currentParts = []*Part{} + } + currentRole = role + + // Convert message content to parts + if msg.Content != nil { + // Handle string content + if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { + currentParts = append(currentParts, &Part{ + Text: *msg.Content.ContentStr, + }) + } + + // Handle content blocks + if msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + part, err := convertContentBlockToGeminiPart(block) + if err == nil && part != nil { + currentParts = append(currentParts, part) + } + } + } + } + + // Handle tool calls (function calls) + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeFunctionCall && msg.ResponsesToolMessage != nil { + argsMap := make(map[string]any) + if msg.ResponsesToolMessage.Arguments != nil { + if err := sonic.Unmarshal([]byte(*msg.ResponsesToolMessage.Arguments), &argsMap); err == nil { + functionCall := &FunctionCall{ + Args: argsMap, + } + if msg.ResponsesToolMessage.Name != nil { + functionCall.Name = *msg.ResponsesToolMessage.Name + } + if msg.ResponsesToolMessage.CallID != nil { + functionCall.ID = *msg.ResponsesToolMessage.CallID + } + + part := &Part{ + FunctionCall: functionCall, + } + + // Check for thought signature in reasoning message + if msg.ResponsesReasoning != nil && msg.ResponsesReasoning.EncryptedContent != nil { + part.ThoughtSignature = []byte(*msg.ResponsesReasoning.EncryptedContent) + } + + currentParts = append(currentParts, part) + } + } + } + + // Handle function responses (function call outputs) + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeFunctionCallOutput && msg.ResponsesToolMessage != nil { + responseMap := make(map[string]any) + + if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { + responseMap["output"] = *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + } + funcName := "" + if msg.ResponsesToolMessage.Name != nil && strings.TrimSpace(*msg.ResponsesToolMessage.Name) != "" { + funcName = *msg.ResponsesToolMessage.Name + } else if msg.ResponsesToolMessage.CallID != nil { + funcName = *msg.ResponsesToolMessage.CallID + } + + functionResponse := &FunctionResponse{ + Name: funcName, + Response: responseMap, + } + if msg.ResponsesToolMessage.CallID != nil { + functionResponse.ID = *msg.ResponsesToolMessage.CallID + } + + currentParts = append(currentParts, &Part{ + FunctionResponse: functionResponse, + }) + } + + // Handle reasoning messages + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeReasoning && msg.ResponsesReasoning != nil { + // Reasoning content is in the Summary array + if len(msg.ResponsesReasoning.Summary) > 0 { + for _, summaryBlock := range msg.ResponsesReasoning.Summary { + if summaryBlock.Text != "" { + currentParts = append(currentParts, &Part{ + Text: summaryBlock.Text, + Thought: true, + }) + } + } + } + if msg.ResponsesReasoning.EncryptedContent != nil { + currentParts = append(currentParts, &Part{ + ThoughtSignature: []byte(*msg.ResponsesReasoning.EncryptedContent), + }) + } + } + } + + // Add the last candidate if we have parts + if len(currentParts) > 0 { + candidate := &Candidate{ + Index: int32(len(candidates)), + Content: &Content{ + Parts: currentParts, + Role: currentRole, + }, + } + + // Determine finish reason based on incomplete details + if bifrostResp.IncompleteDetails != nil { + switch bifrostResp.IncompleteDetails.Reason { + case "max_tokens": + candidate.FinishReason = FinishReasonMaxTokens + case "content_filter": + candidate.FinishReason = FinishReasonSafety + default: + candidate.FinishReason = FinishReasonOther + } + } else { + candidate.FinishReason = FinishReasonStop + } + + candidates = append(candidates, candidate) + } + + geminiResp.Candidates = candidates + } + + // Convert usage metadata + if bifrostResp.Usage != nil { + geminiResp.UsageMetadata = &GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(bifrostResp.Usage.InputTokens), + CandidatesTokenCount: int32(bifrostResp.Usage.OutputTokens), + TotalTokenCount: int32(bifrostResp.Usage.TotalTokens), + } + if bifrostResp.Usage.OutputTokensDetails != nil { + geminiResp.UsageMetadata.ThoughtsTokenCount = int32(bifrostResp.Usage.OutputTokensDetails.ReasoningTokens) + } + } + + return geminiResp +} + +func ToGeminiResponsesStreamResponse(bifrostResp *schemas.BifrostResponsesStreamResponse) *GenerateContentResponse { + if bifrostResp == nil { + return nil + } + + // Skip lifecycle events that don't have corresponding Gemini equivalents + switch bifrostResp.Type { + case schemas.ResponsesStreamResponseTypePing, + schemas.ResponsesStreamResponseTypeCreated, + schemas.ResponsesStreamResponseTypeInProgress, + schemas.ResponsesStreamResponseTypeReasoningSummaryPartAdded, + schemas.ResponsesStreamResponseTypeQueued: + // These are lifecycle events with no Gemini equivalent + return nil + } + + streamResp := &GenerateContentResponse{ + Candidates: []*Candidate{ + { + Content: &Content{ + Parts: []*Part{}, + Role: "model", + }, + }, + }, + } + + candidate := streamResp.Candidates[0] + + switch bifrostResp.Type { + case schemas.ResponsesStreamResponseTypeOutputTextDelta: + if bifrostResp.Delta != nil && *bifrostResp.Delta != "" { + candidate.Content.Parts = append(candidate.Content.Parts, &Part{ + Text: *bifrostResp.Delta, + }) + } + + case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta: + if bifrostResp.Delta != nil && *bifrostResp.Delta != "" { + candidate.Content.Parts = append(candidate.Content.Parts, &Part{ + Text: *bifrostResp.Delta, + Thought: true, + }) + } + + case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + // For streaming, we'll accumulate these, but Gemini typically sends complete calls + // We'll return nil here and let the done event handle it + return nil + + // Function call completed + case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone: + if bifrostResp.Item != nil && bifrostResp.Item.ResponsesToolMessage != nil { + argsMap := make(map[string]any) + if bifrostResp.Item.ResponsesToolMessage.Arguments != nil { + if err := sonic.Unmarshal([]byte(*bifrostResp.Item.ResponsesToolMessage.Arguments), &argsMap); err == nil { + functionCall := &FunctionCall{ + Name: "", + Args: argsMap, + } + if bifrostResp.Item.ResponsesToolMessage.Name != nil { + functionCall.Name = *bifrostResp.Item.ResponsesToolMessage.Name + } + if bifrostResp.Item.ResponsesToolMessage.CallID != nil { + functionCall.ID = *bifrostResp.Item.ResponsesToolMessage.CallID + } + candidate.Content.Parts = append(candidate.Content.Parts, &Part{ + FunctionCall: functionCall, + }) + } + } + } + + case schemas.ResponsesStreamResponseTypeOutputTextDone: + if bifrostResp.Text != nil && *bifrostResp.Text != "" { + candidate.Content.Parts = append(candidate.Content.Parts, &Part{ + Text: *bifrostResp.Text, + }) + } + + case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + schemas.ResponsesStreamResponseTypeReasoningSummaryPartDone: + // Already handled via deltas, skip + return nil + case schemas.ResponsesStreamResponseTypeOutputItemAdded: + if bifrostResp.Item != nil && bifrostResp.Item.ResponsesReasoning != nil && bifrostResp.Item.EncryptedContent != nil { + candidate.Content.Parts = append(candidate.Content.Parts, &Part{ + ThoughtSignature: []byte(*bifrostResp.Item.ResponsesReasoning.EncryptedContent), + }) + } + + case schemas.ResponsesStreamResponseTypeOutputItemDone: + return nil + + case schemas.ResponsesStreamResponseTypeContentPartAdded: + // Handle content parts that contain images, audio, or files + if bifrostResp.Part != nil { + part, err := convertContentBlockToGeminiPart(*bifrostResp.Part) + if err == nil && part != nil { + candidate.Content.Parts = append(candidate.Content.Parts, part) + } + } + + case schemas.ResponsesStreamResponseTypeContentPartDone: + // Already handled via ContentPartAdded + return nil + + case schemas.ResponsesStreamResponseTypeCompleted: + if bifrostResp.Response != nil { + // Set model version if available + if bifrostResp.Response.Model != "" { + streamResp.ModelVersion = bifrostResp.Response.Model + } + + // Convert usage metadata if available + if bifrostResp.Response.Usage != nil { + streamResp.UsageMetadata = &GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(bifrostResp.Response.Usage.InputTokens), + CandidatesTokenCount: int32(bifrostResp.Response.Usage.OutputTokens), + TotalTokenCount: int32(bifrostResp.Response.Usage.TotalTokens), + } + if bifrostResp.Response.Usage.InputTokensDetails != nil { + streamResp.UsageMetadata.CachedContentTokenCount = int32(bifrostResp.Response.Usage.InputTokensDetails.CachedTokens) + } + if bifrostResp.Response.Usage.OutputTokensDetails != nil { + streamResp.UsageMetadata.ThoughtsTokenCount = int32(bifrostResp.Response.Usage.OutputTokensDetails.ReasoningTokens) + } + if bifrostResp.Response.Usage.OutputTokensDetails != nil && bifrostResp.Response.Usage.OutputTokensDetails.AudioTokens > 0 { + // Store audio tokens separately or add proper field + streamResp.UsageMetadata.CandidatesTokensDetails = append(streamResp.UsageMetadata.CandidatesTokensDetails, &ModalityTokenCount{ + Modality: "AUDIO", + TokenCount: int32(bifrostResp.Response.Usage.OutputTokensDetails.AudioTokens), + }) + } + } + + // Set finish reason + candidate.FinishReason = FinishReasonStop + } + + // Response failed + case schemas.ResponsesStreamResponseTypeFailed: + candidate.FinishReason = FinishReasonOther + if bifrostResp.Response != nil && bifrostResp.Response.Error != nil { + streamResp.PromptFeedback = &GenerateContentResponsePromptFeedback{ + BlockReason: "ERROR", + BlockReasonMessage: bifrostResp.Response.Error.Message, + } + } + + // Refusal + case schemas.ResponsesStreamResponseTypeRefusalDelta: + if bifrostResp.Delta != nil && *bifrostResp.Delta != "" { + candidate.Content.Parts = append(candidate.Content.Parts, &Part{ + Text: *bifrostResp.Delta, + }) + } + + case schemas.ResponsesStreamResponseTypeRefusalDone: + if bifrostResp.Refusal != nil && *bifrostResp.Refusal != "" { + candidate.FinishReason = FinishReasonSafety + } + + default: + // For any other event types we don't explicitly handle, return nil + return nil + } + + // If we didn't add any parts and there's no metadata, return nil + if len(candidate.Content.Parts) == 0 && streamResp.UsageMetadata == nil && + streamResp.PromptFeedback == nil && candidate.FinishReason == "" { + return nil + } + + return streamResp +} + +// GeminiResponsesStreamState tracks state during streaming conversion for responses API +type GeminiResponsesStreamState struct { + // Lifecycle flags + HasEmittedCreated bool // Whether response.created has been sent + HasEmittedInProgress bool // Whether response.in_progress has been sent + HasEmittedCompleted bool // Whether response.completed has been sent + + // Item tracking + CurrentOutputIndex int // Current output index + ItemIDs map[int]string // Maps output_index to item ID + TextItemClosed bool // Whether text item has been closed + + // Tool call tracking + ToolCallIDs map[int]string // Maps output_index to tool call ID + ToolCallNames map[int]string // Maps output_index to tool name + ToolArgumentBuffers map[int]string // Accumulates tool arguments as JSON + + // Response metadata + MessageID *string // Generated message ID + Model *string // Model version + CreatedAt int // Timestamp for consistency + ResponseID *string // Gemini's responseId + + // Content tracking + HasStartedText bool // Whether we've started text content + HasStartedToolCall bool // Whether we've started a tool call +} + +// geminiResponsesStreamStatePool provides a pool for Gemini responses stream state objects. +var geminiResponsesStreamStatePool = sync.Pool{ + New: func() interface{} { + return &GeminiResponsesStreamState{ + ItemIDs: make(map[int]string), + ToolCallIDs: make(map[int]string), + ToolCallNames: make(map[int]string), + ToolArgumentBuffers: make(map[int]string), + CurrentOutputIndex: 0, + CreatedAt: int(time.Now().Unix()), + HasEmittedCreated: false, + HasEmittedInProgress: false, + HasEmittedCompleted: false, + TextItemClosed: false, + HasStartedText: false, + HasStartedToolCall: false, + } + }, +} + +// acquireGeminiResponsesStreamState gets a Gemini responses stream state from the pool. +func acquireGeminiResponsesStreamState() *GeminiResponsesStreamState { + state := geminiResponsesStreamStatePool.Get().(*GeminiResponsesStreamState) + // Clear maps + if state.ItemIDs == nil { + state.ItemIDs = make(map[int]string) + } else { + clear(state.ItemIDs) + } + if state.ToolCallIDs == nil { + state.ToolCallIDs = make(map[int]string) + } else { + clear(state.ToolCallIDs) + } + if state.ToolCallNames == nil { + state.ToolCallNames = make(map[int]string) + } else { + clear(state.ToolCallNames) + } + if state.ToolArgumentBuffers == nil { + state.ToolArgumentBuffers = make(map[int]string) + } else { + clear(state.ToolArgumentBuffers) + } + // Reset other fields + state.CurrentOutputIndex = 0 + state.MessageID = nil + state.Model = nil + state.ResponseID = nil + state.CreatedAt = int(time.Now().Unix()) + state.HasEmittedCreated = false + state.HasEmittedInProgress = false + state.HasEmittedCompleted = false + state.TextItemClosed = false + state.HasStartedText = false + state.HasStartedToolCall = false + return state +} + +// releaseGeminiResponsesStreamState returns a Gemini responses stream state to the pool. +func releaseGeminiResponsesStreamState(state *GeminiResponsesStreamState) { + if state != nil { + state.flush() + geminiResponsesStreamStatePool.Put(state) + } +} + +func (state *GeminiResponsesStreamState) flush() { + // Clear maps + if state.ItemIDs == nil { + state.ItemIDs = make(map[int]string) + } else { + clear(state.ItemIDs) + } + if state.ToolCallIDs == nil { + state.ToolCallIDs = make(map[int]string) + } else { + clear(state.ToolCallIDs) + } + if state.ToolCallNames == nil { + state.ToolCallNames = make(map[int]string) + } else { + clear(state.ToolCallNames) + } + if state.ToolArgumentBuffers == nil { + state.ToolArgumentBuffers = make(map[int]string) + } else { + clear(state.ToolArgumentBuffers) + } + state.CurrentOutputIndex = 0 + state.MessageID = nil + state.Model = nil + state.ResponseID = nil + state.CreatedAt = int(time.Now().Unix()) + state.HasEmittedCreated = false + state.HasEmittedCompleted = false + state.HasEmittedInProgress = false + state.TextItemClosed = false + state.HasStartedText = false + state.HasStartedToolCall = false +} + +// ToBifrostResponsesStream converts a Gemini stream event to Bifrost Responses Stream responses +func (response *GenerateContentResponse) ToBifrostResponsesStream(sequenceNumber int, state *GeminiResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError) { + var responses []*schemas.BifrostResponsesStreamResponse + + // First event: Emit response.created and response.in_progress + if !state.HasEmittedCreated { + // Generate message ID + if state.MessageID == nil { + messageID := fmt.Sprintf("msg_%d", state.CreatedAt) + state.MessageID = &messageID + } + + // Set model and response ID from Gemini + if response.ModelVersion != "" && state.Model == nil { + state.Model = &response.ModelVersion + } + if response.ResponseID != "" && state.ResponseID == nil { + state.ResponseID = &response.ResponseID + } + + // Emit response.created + createdResp := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, + } + if state.Model != nil { + createdResp.Model = *state.Model + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCreated, + SequenceNumber: sequenceNumber + len(responses), + Response: createdResp, + }) + state.HasEmittedCreated = true + + // Emit response.in_progress + inProgressResp := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, + } + if state.Model != nil { + inProgressResp.Model = *state.Model + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeInProgress, + SequenceNumber: sequenceNumber + len(responses), + Response: inProgressResp, + }) + state.HasEmittedInProgress = true + } + + // Process candidates + if len(response.Candidates) > 0 { + candidate := response.Candidates[0] + if candidate.Content != nil && len(candidate.Content.Parts) > 0 { + for _, part := range candidate.Content.Parts { + partResponses := processGeminiPart(part, state, sequenceNumber+len(responses)) + responses = append(responses, partResponses...) + } + } + + // Check for finish reason (indicates end of generation) + // Only close if we've actually started emitting content (text, tool calls, etc.) + // This prevents emitting response.completed for empty chunks with just finishReason + if candidate.FinishReason != "" && (state.HasStartedText || state.HasStartedToolCall) { + // Close any open items + closeResponses := closeGeminiOpenItems(state, response.UsageMetadata, sequenceNumber+len(responses)) + responses = append(responses, closeResponses...) + } + } + + return responses, nil +} + +// processGeminiPart processes a single Gemini part and returns appropriate lifecycle events +func processGeminiPart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + switch { + case part.Text != "" && !part.Thought: + // Regular text content + responses = append(responses, processGeminiTextPart(part, state, sequenceNumber)...) + + case part.Thought && part.Text != "": + // Reasoning/thinking content + responses = append(responses, processGeminiThoughtPart(part, state, sequenceNumber)...) + + case part.ThoughtSignature != nil: + // Encrypted reasoning content (thoughtSignature) + responses = append(responses, processGeminiThoughtSignaturePart(part, state, sequenceNumber)...) + + case part.FunctionCall != nil: + // Function call + responses = append(responses, processGeminiFunctionCallPart(part, state, sequenceNumber)...) + case part.FunctionResponse != nil: + // Function response (tool result) + responses = append(responses, processGeminiFunctionResponsePart(part, state, sequenceNumber)...) + case part.InlineData != nil: + // Inline data + responses = append(responses, processGeminiInlineDataPart(part, state, sequenceNumber)...) + case part.FileData != nil: + // File data + responses = append(responses, processGeminiFileDataPart(part, state, sequenceNumber)...) + } + + return responses +} + +// processGeminiTextPart handles regular text parts +func processGeminiTextPart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // If this is the first text, emit output_item.added and content_part.added + if !state.HasStartedText { + outputIndex := 0 + state.CurrentOutputIndex = outputIndex + + itemID := fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) + state.ItemIDs[outputIndex] = itemID + + // Emit output_item.added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, + }, + }, + }) + + // Emit content_part.added + contentIndex := 0 + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Part: &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: schemas.Ptr(""), + }, + }) + + state.HasStartedText = true + } + + // Emit output_text.delta for the text content + if part.Text != "" { + outputIndex := 0 + itemID := state.ItemIDs[outputIndex] + contentIndex := 0 + text := part.Text + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Delta: &text, + }) + } + + return responses +} + +// processGeminiThoughtPart handles reasoning/thought parts +func processGeminiThoughtPart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // For Gemini thoughts/reasoning, we emit them as reasoning summary text deltas + // Initialize reasoning item if not already done + outputIndex := state.CurrentOutputIndex + 1 + if !state.HasStartedText { + outputIndex = 1 + } + state.CurrentOutputIndex = outputIndex + + itemID := fmt.Sprintf("msg_%s_reasoning_%d", *state.MessageID, outputIndex) + state.ItemIDs[outputIndex] = itemID + + // Emit output_item.added for reasoning + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + }, + }) + + // Emit reasoning summary part added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + }) + + // Emit reasoning summary text delta with the thought content + if part.Text != "" { + text := part.Text + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Delta: &text, + }) + } + + // Emit reasoning summary text done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + }) + + // Emit reasoning summary part done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + }) + + // Emit output_item.done for reasoning + statusCompleted := "completed" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Status: &statusCompleted, + }, + }) + + return responses +} + +// processGeminiThoughtSignaturePart handles encrypted reasoning content (thoughtSignature) +func processGeminiThoughtSignaturePart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // Create a new reasoning item for the thought signature + outputIndex := state.CurrentOutputIndex + 1 + if !state.HasStartedText { + outputIndex = 1 + } + state.CurrentOutputIndex = outputIndex + + itemID := fmt.Sprintf("msg_%s_reasoning_%d", *state.MessageID, outputIndex) + state.ItemIDs[outputIndex] = itemID + + // Convert thoughtSignature to string + thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature) + + // Emit output_item.added for reasoning with encrypted content + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + EncryptedContent: &thoughtSig, + }, + }, + }) + + // Emit output_item.done for reasoning (thought signature is complete) + statusCompleted := "completed" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Status: &statusCompleted, + }, + }) + + return responses +} + +// processGeminiFunctionCallPart handles function call parts +func processGeminiFunctionCallPart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // Start new function call item + outputIndex := state.CurrentOutputIndex + 1 + if !state.HasStartedText { + outputIndex = 1 // If no text, start at index 1 + } + state.CurrentOutputIndex = outputIndex + + toolUseID := part.FunctionCall.ID + if toolUseID == "" { + toolUseID = part.FunctionCall.Name // Fallback to name as ID + } + + state.ItemIDs[outputIndex] = toolUseID + state.ToolCallIDs[outputIndex] = toolUseID + state.ToolCallNames[outputIndex] = part.FunctionCall.Name + + // Convert args to JSON string + argsJSON := "" + if part.FunctionCall.Args != nil { + if argsBytes, err := json.Marshal(part.FunctionCall.Args); err == nil { + argsJSON = string(argsBytes) + } + } + state.ToolArgumentBuffers[outputIndex] = argsJSON + + // Emit output_item.added for function call + status := "in_progress" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &toolUseID, + Item: &schemas.ResponsesMessage{ + ID: &toolUseID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: &status, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolUseID, + Name: &part.FunctionCall.Name, + Arguments: &argsJSON, + }, + }, + }) + + // Gemini sends complete function calls, so immediately emit done event + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &toolUseID, + Arguments: &argsJSON, + Item: &schemas.ResponsesMessage{ + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolUseID, + Name: &part.FunctionCall.Name, + }, + }, + }) + + state.HasStartedToolCall = true + + return responses +} + +// processGeminiFunctionResponsePart handles function response (tool result) parts +func processGeminiFunctionResponsePart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // Extract output from function response + output := extractFunctionResponseOutput(part.FunctionResponse) + + // Create new output item for the function response + outputIndex := state.CurrentOutputIndex + 1 + if !state.HasStartedText { + outputIndex = 0 + } + state.CurrentOutputIndex = outputIndex + + responseID := part.FunctionResponse.ID + if responseID == "" { + responseID = part.FunctionResponse.Name // Fallback to name + } + + itemID := fmt.Sprintf("func_resp_%s", responseID) + state.ItemIDs[outputIndex] = itemID + + // Emit output_item.added for function call output + status := "completed" + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Status: &status, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &responseID, + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: &output, + }, + }, + } + + // Set tool name if present + if name := strings.TrimSpace(part.FunctionResponse.Name); name != "" { + item.ResponsesToolMessage.Name = schemas.Ptr(name) + } + + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: item, + }) + + // Immediately emit output_item.done since function responses are complete + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Status: &status, + }, + }) + + return responses +} + +// processGeminiInlineDataPart handles inline data parts +func processGeminiInlineDataPart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // Convert inline data to content block + block := convertGeminiInlineDataToContentBlock(part.InlineData) + if block == nil { + return responses + } + + // Create new output item for the inline data + outputIndex := state.CurrentOutputIndex + 1 + if !state.HasStartedText { + outputIndex = 0 + } + state.CurrentOutputIndex = outputIndex + + itemID := fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) + state.ItemIDs[outputIndex] = itemID + + // Emit output_item.added with the inline data content block + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{*block}, + }, + }, + }) + + // Emit content_part.added + contentIndex := 0 + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Part: block, + }) + + // Emit content_part.done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Part: block, + }) + + // Emit output_item.done + statusCompleted := "completed" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Status: &statusCompleted, + }, + }) + + return responses +} + +// processGeminiFileDataPart handles file data parts +func processGeminiFileDataPart(part *Part, state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // Convert file data to content block + block := convertGeminiFileDataToContentBlock(part.FileData) + if block == nil { + return responses + } + + // Create new output item for the file data + outputIndex := state.CurrentOutputIndex + 1 + if !state.HasStartedText { + outputIndex = 0 + } + state.CurrentOutputIndex = outputIndex + + itemID := fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) + state.ItemIDs[outputIndex] = itemID + + // Emit output_item.added with the file data content block + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{*block}, + }, + }, + }) + + // Emit content_part.added + contentIndex := 0 + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Part: block, + }) + + // Emit content_part.done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Part: block, + }) + + // Emit output_item.done + statusCompleted := "completed" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Status: &statusCompleted, + }, + }) + + return responses +} + +// closeGeminiTextItem closes the text item and emits appropriate done events +func closeGeminiTextItem(state *GeminiResponsesStreamState, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse + + outputIndex := 0 + itemID := state.ItemIDs[outputIndex] + contentIndex := 0 + + // Emit output_text.done + emptyText := "" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + Text: &emptyText, + }) + + // Emit content_part.done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ContentIndex: &contentIndex, + ItemID: &itemID, + }) + + // Emit output_item.done + doneItem := &schemas.ResponsesMessage{ + Status: schemas.Ptr("completed"), + } + if itemID != "" { + doneItem.ID = &itemID + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: doneItem, + }) + + state.TextItemClosed = true + + return responses +} + +// closeGeminiOpenItems closes any open items and emits the final completed event +func closeGeminiOpenItems(state *GeminiResponsesStreamState, usage *GenerateContentResponseUsageMetadata, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + if state.HasEmittedCompleted { + return nil + } + + var responses []*schemas.BifrostResponsesStreamResponse + + // Close text item if still open + if state.HasStartedText && !state.TextItemClosed { + responses = append(responses, closeGeminiTextItem(state, sequenceNumber)...) + } + + // Close any open tool calls + for outputIndex := range state.ToolArgumentBuffers { + itemID := state.ItemIDs[outputIndex] + + // Emit output_item.done for tool call + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: &outputIndex, + ItemID: &itemID, + Item: &schemas.ResponsesMessage{ + ID: &itemID, + Status: schemas.Ptr("completed"), + }, + }) + } + + // Emit response.completed with usage + bifrostUsage := convertGeminiUsageMetadataToResponsesUsage(usage) + + completedResp := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, + Usage: bifrostUsage, + } + if state.Model != nil { + completedResp.Model = *state.Model + } + + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + SequenceNumber: sequenceNumber + len(responses), + Response: completedResp, + }) + + state.HasEmittedCompleted = true + + return responses +} + +// FinalizeGeminiResponsesStream finalizes the stream by closing any open items and emitting completed event +func FinalizeGeminiResponsesStream(state *GeminiResponsesStreamState, usage *GenerateContentResponseUsageMetadata, sequenceNumber int) []*schemas.BifrostResponsesStreamResponse { + return closeGeminiOpenItems(state, usage, sequenceNumber) +} + +func convertGeminiContentsToResponsesMessages(contents []Content) []schemas.ResponsesMessage { + var messages []schemas.ResponsesMessage + // Track function call IDs by name to match with responses + functionCallIDs := make(map[string]string) + + for _, content := range contents { + msg := schemas.ResponsesMessage{} + + switch content.Role { + case "model": + msg.Role = schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant) + case "user": + msg.Role = schemas.Ptr(schemas.ResponsesInputMessageRoleUser) + default: + // Default to user for unknown roles + msg.Role = schemas.Ptr(schemas.ResponsesInputMessageRoleUser) + } + + // Convert parts to content blocks + if len(content.Parts) > 0 { + msg.Content = &schemas.ResponsesMessageContent{} + + for _, part := range content.Parts { + // Handle text content + if part.Text != "" && !part.Thought { + block := schemas.ResponsesMessageContentBlock{ + Text: &part.Text, + } + if content.Role == "model" { + block.Type = schemas.ResponsesOutputMessageContentTypeText + } else { + block.Type = schemas.ResponsesInputMessageContentBlockTypeText + } + msg.Content.ContentBlocks = append(msg.Content.ContentBlocks, block) + } + + // Handle thought/reasoning content + if part.Thought && part.Text != "" { + msgType := schemas.ResponsesMessageTypeReasoning + msg.Type = &msgType + msg.Role = nil + // Add reasoning content as text block + block := schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: &part.Text, + } + msg.Content.ContentBlocks = append(msg.Content.ContentBlocks, block) + } + + // Handle inline data (images, audio, files) + if part.InlineData != nil { + block := convertGeminiInlineDataToContentBlock(part.InlineData) + if block != nil { + msg.Content.ContentBlocks = append(msg.Content.ContentBlocks, *block) + } + } + + // Handle file data (URI-based) + if part.FileData != nil { + block := convertGeminiFileDataToContentBlock(part.FileData) + if block != nil { + msg.Content.ContentBlocks = append(msg.Content.ContentBlocks, *block) + } + } + + // Handle function calls + if part.FunctionCall != nil { + msgType := schemas.ResponsesMessageTypeFunctionCall + msg.Type = &msgType + msg.Role = nil + msg.Content = nil // Clear content for function calls + + argsJSON := "{}" + if part.FunctionCall.Args != nil { + if argsBytes, err := sonic.Marshal(part.FunctionCall.Args); err == nil { + argsJSON = string(argsBytes) + } + } + + callID := part.FunctionCall.ID + if callID == "" { + callID = part.FunctionCall.Name + } + + // Track this function call ID by name for later matching with responses + functionCallIDs[part.FunctionCall.Name] = callID + + msg.ResponsesToolMessage = &schemas.ResponsesToolMessage{ + CallID: &callID, + Name: &part.FunctionCall.Name, + Arguments: &argsJSON, + } + } + + // Handle function responses + if part.FunctionResponse != nil { + msgType := schemas.ResponsesMessageTypeFunctionCallOutput + msg.Type = &msgType + msg.Role = nil + msg.Content = nil // Clear content for function call outputs + + responseID := part.FunctionResponse.ID + if responseID == "" { + // Try to find the matching function call ID by name + if callID, ok := functionCallIDs[part.FunctionResponse.Name]; ok { + responseID = callID + } else { + // Fallback to function name if no matching call found + responseID = part.FunctionResponse.Name + } + } + + // Convert response map to string + responseStr := "" + if part.FunctionResponse.Response != nil { + if output, ok := part.FunctionResponse.Response["output"].(string); ok { + responseStr = output + } else if responseBytes, err := sonic.Marshal(part.FunctionResponse.Response); err == nil { + responseStr = string(responseBytes) + } + } + + msg.ResponsesToolMessage = &schemas.ResponsesToolMessage{ + CallID: &responseID, + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: &responseStr, + }, + } + } + } + } + + // Only append message if it has content or is a tool message + if msg.Content != nil || msg.ResponsesToolMessage != nil { + messages = append(messages, msg) + } + } + + return messages +} + +// convertGeminiInlineDataToContentBlock converts Gemini inline data (blob) to content block +func convertGeminiInlineDataToContentBlock(blob *Blob) *schemas.ResponsesMessageContentBlock { + if blob == nil { + return nil + } + + // Determine content type based on MIME type + mimeType := blob.MIMEType + if mimeType == "" { + return nil + } + + // Handle images + if isImageMimeType(mimeType) { + // Convert to base64 data URL + imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(blob.Data)) + return &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeImage, + ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: &imageURL, + }, + } + } + + // Handle audio + if strings.HasPrefix(mimeType, "audio/") { + encodedData := base64.StdEncoding.EncodeToString(blob.Data) + format := mimeType + if strings.HasPrefix(mimeType, "audio/") { + format = mimeType[6:] // Remove "audio/" prefix + } + + return &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeAudio, + Audio: &schemas.ResponsesInputMessageContentBlockAudio{ + Format: format, + Data: encodedData, + }, + } + } + + // Handle other files + encodedData := base64.StdEncoding.EncodeToString(blob.Data) + return &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeFile, + ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{ + FileData: &encodedData, + Filename: &blob.DisplayName, + }, + } +} + +// convertGeminiFileDataToContentBlock converts Gemini file data (URI) to content block +func convertGeminiFileDataToContentBlock(fileData *FileData) *schemas.ResponsesMessageContentBlock { + if fileData == nil || fileData.FileURI == "" { + return nil + } + + mimeType := fileData.MIMEType + if mimeType == "" { + mimeType = "application/octet-stream" + } + + // Handle images + if isImageMimeType(mimeType) { + return &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeImage, + ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: &fileData.FileURI, + }, + } + } + + // Handle other files + return &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeFile, + ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{ + FileURL: &fileData.FileURI, + }, + } +} + +func convertGeminiToolsToResponsesTools(tools []Tool) []schemas.ResponsesTool { + var responsesTools []schemas.ResponsesTool + + for _, tool := range tools { + if len(tool.FunctionDeclarations) > 0 { + for _, fn := range tool.FunctionDeclarations { + responsesTool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr(fn.Name), + Description: schemas.Ptr(fn.Description), + ResponsesToolFunction: &schemas.ResponsesToolFunction{}, + } + // Convert parameters schema if present + if fn.Parameters != nil { + params := convertSchemaToFunctionParameters(fn.Parameters) + responsesTool.ResponsesToolFunction.Parameters = ¶ms + } + responsesTools = append(responsesTools, responsesTool) + } + } + } + + return responsesTools +} + +func convertGeminiToolConfigToToolChoice(toolConfig ToolConfig) *schemas.ResponsesToolChoice { + if toolConfig.FunctionCallingConfig == nil { + return nil + } + + toolChoice := &schemas.ResponsesToolChoiceStruct{ + Type: schemas.ResponsesToolChoiceTypeFunction, + } + + switch toolConfig.FunctionCallingConfig.Mode { + case FunctionCallingConfigModeAuto: + toolChoice.Mode = schemas.Ptr("auto") + case FunctionCallingConfigModeNone: + toolChoice.Mode = schemas.Ptr("none") + default: + toolChoice.Mode = schemas.Ptr("auto") + } + + if toolConfig.FunctionCallingConfig.AllowedFunctionNames != nil { + for _, functionName := range toolConfig.FunctionCallingConfig.AllowedFunctionNames { + toolChoice.Tools = append(toolChoice.Tools, schemas.ResponsesToolChoiceAllowedToolDef{ + Type: string(schemas.ResponsesToolTypeFunction), + Name: schemas.Ptr(functionName), + }) + } + } + + return &schemas.ResponsesToolChoice{ + ResponsesToolChoiceStruct: toolChoice, + } +} + // Helper functions for Responses conversion // convertGeminiCandidatesToResponsesOutput converts Gemini candidates to Responses output messages func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas.ResponsesMessage { @@ -145,19 +1702,19 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas } } - // Create copies of the values to avoid range loop variable capture - functionCallID := part.FunctionCall.ID - functionCallName := part.FunctionCall.Name + callID := part.FunctionCall.ID + if strings.TrimSpace(callID) == "" { + callID = part.FunctionCall.Name + } + name := part.FunctionCall.Name toolMsg := &schemas.ResponsesToolMessage{ - CallID: &functionCallID, - Name: &functionCallName, + CallID: &callID, + Name: &name, Arguments: &argumentsStr, } - msg := schemas.ResponsesMessage{ Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{}, Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), ResponsesToolMessage: toolMsg, } @@ -180,28 +1737,16 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas case part.FunctionResponse != nil: // Function response message - output := "" - if part.FunctionResponse.Response != nil { - if outputVal, ok := part.FunctionResponse.Response["output"]; ok { - if outputStr, ok := outputVal.(string); ok { - output = outputStr - } - } - } + output := extractFunctionResponseOutput(part.FunctionResponse) msg := schemas.ResponsesMessage{ Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &output, - }, - }, - }, Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: schemas.Ptr(part.FunctionResponse.ID), + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: &output, + }, }, } @@ -318,6 +1863,18 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), } messages = append(messages, msg) + case part.ThoughtSignature != nil: + // Handle thought signature + thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature) + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + EncryptedContent: &thoughtSig, + }, + } + messages = append(messages, msg) } } } @@ -326,7 +1883,7 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas } // convertParamsToGenerationConfigResponses converts ChatParameters to GenerationConfig for Responses -func convertParamsToGenerationConfigResponses(params *schemas.ResponsesParameters) GenerationConfig { +func (r *GeminiGenerationRequest) convertParamsToGenerationConfigResponses(params *schemas.ResponsesParameters) GenerationConfig { config := GenerationConfig{} if params.Temperature != nil { @@ -338,6 +1895,22 @@ func convertParamsToGenerationConfigResponses(params *schemas.ResponsesParameter if params.MaxOutputTokens != nil { config.MaxOutputTokens = int32(*params.MaxOutputTokens) } + if params.Reasoning != nil { + config.ThinkingConfig = &GenerationConfigThinkingConfig{ + IncludeThoughts: true, + } + if params.Reasoning.Effort != nil { + switch *params.Reasoning.Effort { + case "minimal", "low": + config.ThinkingConfig.ThinkingLevel = ThinkingLevelLow + case "medium", "high": + config.ThinkingConfig.ThinkingLevel = ThinkingLevelHigh + } + } + if params.Reasoning.MaxTokens != nil { + config.ThinkingConfig.ThinkingBudget = schemas.Ptr(int32(*params.Reasoning.MaxTokens)) + } + } if params.ExtraParams != nil { if topK, ok := params.ExtraParams["top_k"]; ok { @@ -567,11 +2140,17 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag content := Content{} if msg.Role != nil { - content.Role = string(*msg.Role) - } else { - content.Role = "user" // Default role if msg.Role is nil + // Map Responses roles to Gemini roles (Gemini only supports "user" and "model") + switch *msg.Role { + case schemas.ResponsesInputMessageRoleAssistant: + content.Role = "model" + case schemas.ResponsesInputMessageRoleUser, schemas.ResponsesInputMessageRoleDeveloper: + content.Role = "user" + default: + // Default to "user" for input messages (any instructions/context) + content.Role = "user" + } } - // Convert message content if msg.Content != nil { if msg.Content.ContentStr != nil { @@ -668,13 +2247,30 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag // convertContentBlockToGeminiPart converts a content block to Gemini part func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) (*Part, error) { switch block.Type { - case schemas.ResponsesInputMessageContentBlockTypeText: - if block.Text != nil { + case schemas.ResponsesInputMessageContentBlockTypeText, + schemas.ResponsesOutputMessageContentTypeText: + if block.Text != nil && *block.Text != "" { return &Part{ Text: *block.Text, }, nil } + case schemas.ResponsesOutputMessageContentTypeReasoning: + if block.Text != nil && *block.Text != "" { + return &Part{ + Text: *block.Text, + Thought: true, + }, nil + } + + case schemas.ResponsesOutputMessageContentTypeRefusal: + // Refusals are treated as regular text in Gemini + if block.ResponsesOutputMessageContentRefusal != nil { + return &Part{ + Text: block.ResponsesOutputMessageContentRefusal.Refusal, + }, nil + } + case schemas.ResponsesInputMessageContentBlockTypeImage: if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { imageURL := *block.ResponsesInputMessageContentBlockImage.ImageURL @@ -754,10 +2350,16 @@ func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) }, }, nil } else if block.ResponsesInputMessageContentBlockFile.FileData != nil { + raw := *block.ResponsesInputMessageContentBlockFile.FileData + data := []byte(raw) + // FileData is base64-encoded + if decoded, err := base64.StdEncoding.DecodeString(raw); err == nil { + data = decoded + } return &Part{ InlineData: &Blob{ MIMEType: "application/octet-stream", // default - Data: []byte(*block.ResponsesInputMessageContentBlockFile.FileData), + Data: data, }, }, nil } diff --git a/core/providers/gemini/speech.go b/core/providers/gemini/speech.go index d1d7b40be..ad38e10d3 100644 --- a/core/providers/gemini/speech.go +++ b/core/providers/gemini/speech.go @@ -13,13 +13,6 @@ import ( func (request *GeminiGenerationRequest) ToBifrostSpeechRequest() *schemas.BifrostSpeechRequest { provider, model := schemas.ParseModelString(request.Model, schemas.Gemini) - if provider == schemas.Vertex { - // Add google/ prefix for Bifrost if not already present - if !strings.HasPrefix(model, "google/") { - model = "google/" + model - } - } - bifrostReq := &schemas.BifrostSpeechRequest{ Provider: provider, Model: model, @@ -107,7 +100,7 @@ func ToGeminiSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) (*GeminiGen // Convert parameters to generation config geminiReq.GenerationConfig.ResponseModalities = []Modality{ModalityAudio} // Convert speech input to Gemini format - if bifrostReq.Input.Input != "" { + if bifrostReq.Input != nil && bifrostReq.Input.Input != "" { geminiReq.Contents = []Content{ { Parts: []*Part{ @@ -156,7 +149,7 @@ func (response *GenerateContentResponse) ToBifrostSpeechResponse(ctx context.Con return nil, fmt.Errorf("failed to convert PCM to WAV: %v", err) } bifrostResp.Audio = wavData - }else{ + } else { bifrostResp.Audio = audioData } } diff --git a/core/providers/gemini/transcription.go b/core/providers/gemini/transcription.go index 6677483cf..399ac4df4 100644 --- a/core/providers/gemini/transcription.go +++ b/core/providers/gemini/transcription.go @@ -10,13 +10,6 @@ import ( func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest() *schemas.BifrostTranscriptionRequest { provider, model := schemas.ParseModelString(request.Model, schemas.Gemini) - if provider == schemas.Vertex { - // Add google/ prefix for Bifrost if not already present - if !strings.HasPrefix(model, "google/") { - model = "google/" + model - } - } - bifrostReq := &schemas.BifrostTranscriptionRequest{ Provider: provider, Model: model, diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index 8a4aab44f..a57b7fa92 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -907,8 +907,19 @@ type GenerationConfigThinkingConfig struct { IncludeThoughts bool `json:"includeThoughts,omitempty"` // Optional. Indicates the thinking budget in tokens. ThinkingBudget *int32 `json:"thinkingBudget,omitempty"` + + // Optional. Indicates the thinking level. + ThinkingLevel ThinkingLevel `json:"thinkingLevel,omitempty"` } +type ThinkingLevel string + +const ( + ThinkingLevelUnspecified ThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" + ThinkingLevelLow ThinkingLevel = "LOW" + ThinkingLevelHigh ThinkingLevel = "HIGH" +) + // GeminiEmbeddingRequest represents a single embedding request in a batch. type GeminiEmbeddingRequest struct { Content *Content `json:"content,omitempty"` diff --git a/core/providers/gemini/utils.go b/core/providers/gemini/utils.go index 3d1c9cf5c..1ea918ee8 100644 --- a/core/providers/gemini/utils.go +++ b/core/providers/gemini/utils.go @@ -8,55 +8,67 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -// convertGenerationConfigToChatParameters converts Gemini GenerationConfig to ChatParameters -func (r *GeminiGenerationRequest) convertGenerationConfigToChatParameters() *schemas.ChatParameters { - params := &schemas.ChatParameters{ +func (r *GeminiGenerationRequest) convertGenerationConfigToResponsesParameters() *schemas.ResponsesParameters { + params := &schemas.ResponsesParameters{ ExtraParams: make(map[string]interface{}), } config := r.GenerationConfig - // Map generation config fields to parameters if config.Temperature != nil { params.Temperature = config.Temperature } if config.TopP != nil { params.TopP = config.TopP } + if config.Logprobs != nil { + params.TopLogProbs = schemas.Ptr(int(*config.Logprobs)) + } if config.TopK != nil { params.ExtraParams["top_k"] = *config.TopK } if config.MaxOutputTokens > 0 { - params.MaxCompletionTokens = schemas.Ptr(int(config.MaxOutputTokens)) + params.MaxOutputTokens = schemas.Ptr(int(config.MaxOutputTokens)) + } + if config.ThinkingConfig != nil { + params.Reasoning = &schemas.ResponsesParametersReasoning{} + if config.ThinkingConfig.ThinkingBudget != nil { + params.Reasoning.MaxTokens = schemas.Ptr(int(*config.ThinkingConfig.ThinkingBudget)) + } + if config.ThinkingConfig.ThinkingLevel != ThinkingLevelUnspecified { + switch config.ThinkingConfig.ThinkingLevel { + case ThinkingLevelLow: + params.Reasoning.Effort = schemas.Ptr("low") + case ThinkingLevelHigh: + params.Reasoning.Effort = schemas.Ptr("high") + } + } } if config.CandidateCount > 0 { params.ExtraParams["candidate_count"] = config.CandidateCount } if len(config.StopSequences) > 0 { - params.Stop = config.StopSequences + params.ExtraParams["stop_sequences"] = config.StopSequences } if config.PresencePenalty != nil { - params.PresencePenalty = config.PresencePenalty + params.ExtraParams["presence_penalty"] = config.PresencePenalty } if config.FrequencyPenalty != nil { - params.FrequencyPenalty = config.FrequencyPenalty + params.ExtraParams["frequency_penalty"] = config.FrequencyPenalty } if config.Seed != nil { - params.Seed = schemas.Ptr(int(*config.Seed)) + params.ExtraParams["seed"] = int(*config.Seed) } if config.ResponseMIMEType != "" { - params.ExtraParams["response_mime_type"] = config.ResponseMIMEType - - // Convert Gemini's response format to OpenAI's response_format for compatibility switch config.ResponseMIMEType { case "application/json": - params.ResponseFormat = buildOpenAIResponseFormat(config.ResponseSchema, config.ResponseJSONSchema) + params.Text = buildOpenAIResponseFormat(config.ResponseSchema, config.ResponseJSONSchema) case "text/plain": - // Gemini text/plain → OpenAI text format - var responseFormat interface{} = map[string]interface{}{ - "type": "text", + params.Text = &schemas.ResponsesTextConfig{ + Format: &schemas.ResponsesTextConfigFormat{ + Type: "text", + }, } - params.ResponseFormat = &responseFormat } } if config.ResponseSchema != nil { @@ -68,15 +80,11 @@ func (r *GeminiGenerationRequest) convertGenerationConfigToChatParameters() *sch if config.ResponseLogprobs { params.ExtraParams["response_logprobs"] = config.ResponseLogprobs } - if config.Logprobs != nil { - params.ExtraParams["logprobs"] = *config.Logprobs - } - return params } // convertSchemaToFunctionParameters converts genai.Schema to schemas.FunctionParameters -func (r *GeminiGenerationRequest) convertSchemaToFunctionParameters(schema *Schema) schemas.ToolFunctionParameters { +func convertSchemaToFunctionParameters(schema *Schema) schemas.ToolFunctionParameters { params := schemas.ToolFunctionParameters{ Type: strings.ToLower(string(schema.Type)), } @@ -197,16 +205,30 @@ func isImageMimeType(mimeType string) bool { return false } -// ensureExtraParams ensures that bifrostReq.Params and bifrostReq.Params.ExtraParams are initialized -func ensureExtraParams(bifrostReq *schemas.BifrostChatRequest) { - if bifrostReq.Params == nil { - bifrostReq.Params = &schemas.ChatParameters{ - ExtraParams: make(map[string]interface{}), - } +var ( + // Maps Gemini finish reasons to Bifrost format + geminiFinishReasonToBifrost = map[FinishReason]string{ + FinishReasonStop: "stop", + FinishReasonMaxTokens: "length", + FinishReasonSafety: "content_filter", + FinishReasonRecitation: "content_filter", + FinishReasonLanguage: "content_filter", + FinishReasonOther: "stop", + FinishReasonBlocklist: "content_filter", + FinishReasonProhibitedContent: "content_filter", + FinishReasonSPII: "content_filter", + FinishReasonMalformedFunctionCall: "tool_calls", + FinishReasonImageSafety: "content_filter", + FinishReasonUnexpectedToolCall: "tool_calls", } - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) +) + +// ConvertGeminiFinishReasonToBifrost converts Gemini finish reasons to Bifrost format +func ConvertGeminiFinishReasonToBifrost(providerReason FinishReason) string { + if bifrostReason, ok := geminiFinishReasonToBifrost[providerReason]; ok { + return bifrostReason } + return string(providerReason) } // extractUsageMetadata extracts usage metadata from the Gemini response @@ -222,6 +244,72 @@ func (r *GenerateContentResponse) extractUsageMetadata() (int, int, int, int, in return inputTokens, outputTokens, totalTokens, cachedTokens, reasoningTokens } +// convertGeminiUsageMetadataToChatUsage converts Gemini usage metadata to Bifrost chat LLM usage +func convertGeminiUsageMetadataToChatUsage(metadata *GenerateContentResponseUsageMetadata) *schemas.BifrostLLMUsage { + if metadata == nil { + return nil + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: int(metadata.PromptTokenCount), + CompletionTokens: int(metadata.CandidatesTokenCount), + TotalTokens: int(metadata.TotalTokenCount), + } + + // Add cached tokens if present + if metadata.CachedContentTokenCount > 0 { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + CachedTokens: int(metadata.CachedContentTokenCount), + } + } + + // Add reasoning tokens if present + if metadata.ThoughtsTokenCount > 0 { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + ReasoningTokens: int(metadata.ThoughtsTokenCount), + } + } + + return usage +} + +// convertGeminiUsageMetadataToResponsesUsage converts Gemini usage metadata to Bifrost responses usage +func convertGeminiUsageMetadataToResponsesUsage(metadata *GenerateContentResponseUsageMetadata) *schemas.ResponsesResponseUsage { + if metadata == nil { + return nil + } + + usage := &schemas.ResponsesResponseUsage{ + TotalTokens: int(metadata.TotalTokenCount), + InputTokens: int(metadata.PromptTokenCount), + OutputTokens: int(metadata.CandidatesTokenCount), + OutputTokensDetails: &schemas.ResponsesResponseOutputTokens{}, + InputTokensDetails: &schemas.ResponsesResponseInputTokens{}, + } + + // Add cached tokens if present + if metadata.CachedContentTokenCount > 0 { + usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + CachedTokens: int(metadata.CachedContentTokenCount), + } + } + + if metadata.CandidatesTokensDetails != nil { + for _, detail := range metadata.CandidatesTokensDetails { + switch detail.Modality { + case "AUDIO": + usage.OutputTokensDetails.AudioTokens = int(detail.TokenCount) + } + } + } + + if metadata.ThoughtsTokenCount > 0 { + usage.OutputTokensDetails.ReasoningTokens = int(metadata.ThoughtsTokenCount) + } + + return usage +} + // convertParamsToGenerationConfig converts Bifrost parameters to Gemini GenerationConfig func convertParamsToGenerationConfig(params *schemas.ChatParameters, responseModalities []string) GenerationConfig { config := GenerationConfig{} @@ -258,6 +346,22 @@ func convertParamsToGenerationConfig(params *schemas.ChatParameters, responseMod penalty := float64(*params.FrequencyPenalty) config.FrequencyPenalty = &penalty } + if params.Reasoning != nil { + config.ThinkingConfig = &GenerationConfigThinkingConfig{ + IncludeThoughts: true, + } + if params.Reasoning.MaxTokens != nil { + config.ThinkingConfig.ThinkingBudget = schemas.Ptr(int32(*params.Reasoning.MaxTokens)) + } + if params.Reasoning.Effort != nil { + switch *params.Reasoning.Effort { + case "minimal", "low": + config.ThinkingConfig.ThinkingLevel = ThinkingLevelLow + case "medium", "high": + config.ThinkingConfig.ThinkingLevel = ThinkingLevelHigh + } + } + } // Handle response_format to response_schema conversion if params.ResponseFormat != nil { @@ -766,7 +870,7 @@ func convertGeminiTypeToJSONSchemaType(geminiType string) string { } // buildOpenAIResponseFormat builds OpenAI response_format for JSON types -func buildOpenAIResponseFormat(responseSchema *Schema, responseJsonSchema interface{}) *interface{} { +func buildOpenAIResponseFormat(responseSchema *Schema, responseJsonSchema interface{}) *schemas.ResponsesTextConfig { var schema interface{} name := "response_schema" @@ -800,22 +904,23 @@ func buildOpenAIResponseFormat(responseSchema *Schema, responseJsonSchema interf } } else { // No schema provided - use older json_object mode - var format interface{} = map[string]interface{}{ - "type": "json_object", + return &schemas.ResponsesTextConfig{ + Format: &schemas.ResponsesTextConfigFormat{ + Type: "json_object", + }, } - return &format } - // Has schema - use json_schema mode (Structured Outputs) - var format interface{} = map[string]interface{}{ - "type": "json_schema", - "json_schema": map[string]interface{}{ - "name": name, - "strict": false, - "schema": schema, + return &schemas.ResponsesTextConfig{ + Format: &schemas.ResponsesTextConfigFormat{ + Type: "json_schema", + JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{ + Name: schemas.Ptr(name), + Strict: schemas.Ptr(false), + Schema: schemas.Ptr(schema), + }, }, } - return &format } // extractSchemaFromResponseFormat extracts Gemini Schema from OpenAI's response_format structure @@ -858,3 +963,26 @@ func extractSchemaFromResponseFormat(responseFormat *interface{}) *Schema { return schema } + +// extractFunctionResponseOutput extracts the output text from a FunctionResponse. +// It first tries to extract the "output" field if present, otherwise marshals the entire response. +// Returns an empty string if the response is nil or extraction fails. +func extractFunctionResponseOutput(funcResp *FunctionResponse) string { + if funcResp == nil || funcResp.Response == nil { + return "" + } + + // Try to extract "output" field first + if outputVal, ok := funcResp.Response["output"]; ok { + if outputStr, ok := outputVal.(string); ok { + return outputStr + } + } + + // If no "output" key, marshal the entire response + if jsonResponse, err := sonic.Marshal(funcResp.Response); err == nil { + return string(jsonResponse) + } + + return "" +} diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 636fe8648..8bde52de0 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -17,6 +17,7 @@ import ( "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/anthropic" + "github.com/maximhq/bifrost/core/providers/gemini" "github.com/maximhq/bifrost/core/providers/openai" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" @@ -279,6 +280,10 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. } deployment := provider.getModelDeployment(key, request.Model) + // strip google/ prefix if present + if after, ok := strings.CutPrefix(deployment, "google/"); ok { + deployment = after + } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -306,6 +311,22 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. if err := sonic.Unmarshal(reqBytes, &requestBody); err != nil { return nil, fmt.Errorf("failed to unmarshal request body: %w", err) } + } else if schemas.IsGeminiModel(deployment) { + reqBody := gemini.ToGeminiChatCompletionRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("chat completion input is not provided") + } + reqBody.Model = deployment + // Strip unsupported fields for Vertex Gemini + stripVertexGeminiUnsupportedFields(reqBody) + // Convert struct to map for Vertex API + reqBytes, err := sonic.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + if err := sonic.Unmarshal(reqBytes, &requestBody); err != nil { + return nil, fmt.Errorf("failed to unmarshal request body: %w", err) + } } else { // Use centralized OpenAI converter for non-Claude models reqBody := openai.ToOpenAIChatRequest(request) @@ -379,6 +400,12 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. } else { completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, deployment) } + } else if schemas.IsGeminiModel(deployment) { + if region == "global" { + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + } else { + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + } } else { // Other models use OpenAPI endpoint for gemini models if key.Value != "" { @@ -470,6 +497,32 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. response.ExtraFields.RawResponse = rawResponse } + return response, nil + } else if schemas.IsGeminiModel(deployment) { + geminiResponse := gemini.GenerateContentResponse{} + + rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &geminiResponse, jsonBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response := geminiResponse.ToBifrostChatResponse() + response.ExtraFields.RequestType = schemas.ChatCompletionRequest + response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model + if request.Model != deployment { + response.ExtraFields.ModelDeployment = deployment + } + response.ExtraFields.Latency = latency.Milliseconds() + + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + response.ExtraFields.RawRequest = rawRequest + } + + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + response.ExtraFields.RawResponse = rawResponse + } + return response, nil } else { response := &schemas.BifrostChatResponse{} @@ -522,6 +575,10 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo } deployment := provider.getModelDeployment(key, request.Model) + // strip google/ prefix if present + if after, ok := strings.CutPrefix(deployment, "google/"); ok { + deployment = after + } postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { response.ExtraFields.ModelRequested = request.Model @@ -609,6 +666,81 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo postResponseConverter, provider.logger, ) + } else if schemas.IsGeminiModel(deployment) { + // Use Gemini-style streaming for Gemini models + jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + reqBody := gemini.ToGeminiChatCompletionRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("chat completion input is not provided") + } + reqBody.Model = deployment + // Strip unsupported fields for Vertex Gemini + stripVertexGeminiUnsupportedFields(reqBody) + return reqBody, nil + }, + provider.GetProviderKey()) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Auth query is used to pass the API key in the query string + authQuery := "" + if key.Value != "" { + authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value)) + } + + // Construct the URL for Gemini streaming + var completeURL string + if region == "global" { + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", projectID, deployment) + } else { + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:streamGenerateContent", region, projectID, region, deployment) + } + + // Add alt=sse parameter + if authQuery != "" { + completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery) + } else { + completeURL = fmt.Sprintf("%s?alt=sse", completeURL) + } + + // Prepare headers for Vertex Gemini + headers := map[string]string{ + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // If no auth query, use OAuth2 token + if authQuery == "" { + tokenSource, err := getAuthTokenSource(key) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + } + token, err := tokenSource.Token() + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + } + headers["Authorization"] = "Bearer " + token.AccessToken + } + + // Use shared streaming logic from Gemini + return gemini.HandleGeminiChatCompletionStream( + ctx, + provider.client, + completeURL, + jsonData, + headers, + provider.networkConfig.ExtraHeaders, + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + request.Model, + postHookRunner, + postResponseConverter, + provider.logger, + ) } else { var authHeader map[string]string // Auth query is used for fine-tuned models to pass the API key in the query string @@ -699,6 +831,10 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, } deployment := provider.getModelDeployment(key, request.Model) + // strip google/ prefix if present + if after, ok := strings.CutPrefix(deployment, "google/"); ok { + deployment = after + } if schemas.IsAnthropicModel(deployment) { jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false) @@ -719,9 +855,9 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, // Claude models use Anthropic publisher var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) } // Create HTTP request for streaming @@ -796,6 +932,106 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, response.ExtraFields.ModelDeployment = deployment } + return response, nil + } else if schemas.IsGeminiModel(deployment) { + jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + reqBody := gemini.ToGeminiResponsesRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("responses input is not provided") + } + reqBody.Model = deployment + // Strip unsupported fields for Vertex Gemini + stripVertexGeminiUnsupportedFields(reqBody) + return reqBody, nil + }, + provider.GetProviderKey()) + if bifrostErr != nil { + return nil, bifrostErr + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { + return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + } + + region := key.VertexKeyConfig.Region + if region == "" { + return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + } + + var url string + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + } + + // Create HTTP request for streaming + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + + // Getting oauth2 token + tokenSource, err := getAuthTokenSource(key) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + } + token, err := tokenSource.Token() + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + } + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + + req.SetRequestURI(url) + req.SetBody(jsonBody) + + // Make the request + latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + if resp.StatusCode() != fasthttp.StatusOK { + // Remove client from pool for authentication/authorization errors + if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + } + return nil, parseVertexError(providerName, resp) + } + + geminiResponse := &gemini.GenerateContentResponse{} + + rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), geminiResponse, jsonBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response := geminiResponse.ToResponsesBifrostResponsesResponse() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.Latency = latency.Milliseconds() + + if request.Model != deployment { + response.ExtraFields.ModelDeployment = deployment + } + + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + response.ExtraFields.RawResponse = rawResponse + } + + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + response.ExtraFields.RawRequest = rawRequest + } + return response, nil } else { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) @@ -826,6 +1062,10 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun } deployment := provider.getModelDeployment(key, request.Model) + // strip google/ prefix if present + if after, ok := strings.CutPrefix(deployment, "google/"); ok { + deployment = after + } if schemas.IsAnthropicModel(deployment) { region := key.VertexKeyConfig.Region @@ -891,6 +1131,99 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun postResponseConverter, provider.logger, ) + } else if schemas.IsGeminiModel(deployment) { + region := key.VertexKeyConfig.Region + if region == "" { + return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { + return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + } + + // Use Gemini-style streaming for Gemini models + jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + reqBody := gemini.ToGeminiResponsesRequest(request) + if reqBody == nil { + return nil, fmt.Errorf("responses input is not provided") + } + reqBody.Model = deployment + // Strip unsupported fields for Vertex Gemini + stripVertexGeminiUnsupportedFields(reqBody) + return reqBody, nil + }, + provider.GetProviderKey()) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Auth query is used to pass the API key in the query string + authQuery := "" + if key.Value != "" { + authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value)) + } + + // Construct the URL for Gemini streaming + var completeURL string + if region == "global" { + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", projectID, deployment) + } else { + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:streamGenerateContent", region, projectID, region, deployment) + } + + // Add alt=sse parameter + if authQuery != "" { + completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery) + } else { + completeURL = fmt.Sprintf("%s?alt=sse", completeURL) + } + + // Prepare headers for Vertex Gemini + headers := map[string]string{ + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // If no auth query, use OAuth2 token + if authQuery == "" { + tokenSource, err := getAuthTokenSource(key) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + } + token, err := tokenSource.Token() + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + } + headers["Authorization"] = "Bearer " + token.AccessToken + } + + postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { + response.ExtraFields.ModelRequested = request.Model + if request.Model != deployment { + response.ExtraFields.ModelDeployment = deployment + } + return response + } + + // Use shared streaming logic from Gemini + return gemini.HandleGeminiResponsesStream( + ctx, + provider.client, + completeURL, + jsonData, + headers, + provider.networkConfig.ExtraHeaders, + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + request.Model, + postHookRunner, + postResponseConverter, + provider.logger, + ) } else { ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( @@ -1056,6 +1389,24 @@ func (provider *VertexProvider) TranscriptionStream(ctx context.Context, postHoo return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } +// stripVertexGeminiUnsupportedFields removes fields that are not supported by Vertex AI's Gemini API. +// Specifically, it removes the "id" field from function_call and function_response objects in contents. +func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequest) { + for _, content := range requestBody.Contents { + for _, part := range content.Parts { + // Remove id from function_call + if part.FunctionCall != nil { + part.FunctionCall.ID = "" + } + + // Remove id from function_response + if part.FunctionResponse != nil { + part.FunctionResponse.ID = "" + } + } + } +} + func (provider *VertexProvider) getModelDeployment(key schemas.Key, model string) string { if key.VertexKeyConfig == nil { return model diff --git a/core/providers/vertex/vertex_test.go b/core/providers/vertex/vertex_test.go index 0ce8cc64e..cd9306e55 100644 --- a/core/providers/vertex/vertex_test.go +++ b/core/providers/vertex/vertex_test.go @@ -25,7 +25,7 @@ func TestVertex(t *testing.T) { testConfig := testutil.ComprehensiveTestConfig{ Provider: schemas.Vertex, ChatModel: "google/gemini-2.0-flash-001", - VisionModel: "google/gemini-2.0-flash-001", + VisionModel: "gemini-2.0-flash-001", TextModel: "", // Vertex doesn't support text completion in newer models EmbeddingModel: "text-multilingual-embedding-002", ReasoningModel: "claude-4.5-haiku", @@ -36,7 +36,7 @@ func TestVertex(t *testing.T) { MultiTurnConversation: true, ToolCalls: true, ToolCallsStreaming: true, - MultipleToolCalls: true, + MultipleToolCalls: false, // multiple tool calls supported on gemini endpoint only if all tools are search tools End2EndToolCalling: true, AutomaticFunctionCall: true, ImageURL: true, diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 3df6d9618..5b974fcaa 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -131,6 +131,10 @@ type ResponsesTextConfigFormat struct { // ResponsesTextConfigFormatJSONSchema represents a JSON schema specification type ResponsesTextConfigFormatJSONSchema struct { + Name *string `json:"name,omitempty"` + Schema *any `json:"schema,omitempty"` + Description *string `json:"description,omitempty"` + Strict *bool `json:"strict,omitempty"` AdditionalProperties *bool `json:"additionalProperties,omitempty"` Properties *map[string]any `json:"properties,omitempty"` Required []string `json:"required,omitempty"` diff --git a/core/schemas/utils.go b/core/schemas/utils.go index a450b3ed5..fbad2c13a 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -1049,6 +1049,10 @@ func IsMistralModel(model string) bool { return strings.Contains(model, "mistral") || strings.Contains(model, "codestral") } +func IsGeminiModel(model string) bool { + return strings.Contains(model, "gemini") +} + // Precompiled regexes for different kinds of version suffixes. var ( // Anthropic-style date: 20250514 diff --git a/tests/integrations/config.yml b/tests/integrations/config.yml index 1b51e7a25..043e7d723 100644 --- a/tests/integrations/config.yml +++ b/tests/integrations/config.yml @@ -60,13 +60,13 @@ providers: - "claude-3-haiku-20240307" gemini: - chat: "gemini-2.0-flash" - vision: "gemini-2.0-flash" - tools: "gemini-2.0-flash" + chat: "gemini-2.5-flash" + vision: "gemini-2.5-flash" + tools: "gemini-2.5-flash" speech: "gemini-2.5-flash-preview-tts" transcription: "gemini-2.5-flash" embeddings: "gemini-embedding-001" - streaming: "gemini-2.0-flash" + streaming: "gemini-2.5-flash" alternatives: - "gemini-1.5-pro" - "gemini-1.5-flash" diff --git a/tests/integrations/tests/test_google.py b/tests/integrations/tests/test_google.py index c5ca60ae1..4798e6510 100644 --- a/tests/integrations/tests/test_google.py +++ b/tests/integrations/tests/test_google.py @@ -34,6 +34,8 @@ 24. Speech generation - different voices 25. Speech generation - language support 26. Speech generation - streaming (if supported) +27. Extended thinking/reasoning (non-streaming) +28. Extended thinking/reasoning (streaming) """ import pytest @@ -81,6 +83,9 @@ GENAI_INVALID_ROLE_CONTENT, EMBEDDINGS_SINGLE_TEXT, SPEECH_TEST_INPUT, + # Gemini-specific test data + GEMINI_REASONING_PROMPT, + GEMINI_REASONING_STREAMING_PROMPT, ) from .utils.config_loader import get_model from .utils.parametrize import ( @@ -506,15 +511,14 @@ def test_11_integration_specific_features(self, google_client, test_config): @skip_if_no_api_key("google") def test_12_error_handling_invalid_roles(self, google_client, test_config): """Test Case 12: Error handling for invalid roles""" - with pytest.raises(Exception) as exc_info: - google_client.models.generate_content( - model=get_model("google", "chat"), contents=GENAI_INVALID_ROLE_CONTENT - ) + response = google_client.models.generate_content( + model=get_model("google", "chat"), contents=GENAI_INVALID_ROLE_CONTENT + ) - # Verify the error is properly caught and contains role-related information - error = exc_info.value - assert_valid_error_response(error, "tester") - assert_error_propagation(error, "google") + # Verify the response is successful + assert response is not None + assert hasattr(response, "candidates") + assert len(response.candidates) > 0 @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("streaming")) def test_13_streaming(self, google_client, test_config, provider, model): @@ -882,6 +886,172 @@ def test_25_speech_generation_language_support(self, google_client, test_config, wav_audio = convert_pcm_to_wav(audio_data) assert_valid_speech_response(wav_audio, expected_audio_size_min=1000) + @skip_if_no_api_key("google") + def test_26_extended_thinking(self, google_client, test_config): + """Test Case 26: Extended thinking/reasoning (non-streaming)""" + from google.genai import types + + # Convert to Google GenAI message format + messages = GEMINI_REASONING_PROMPT[0]["content"] + + # Use a thinking-capable model (Gemini 2.0+ supports thinking) + response = google_client.models.generate_content( + model=get_model("google", "chat"), + contents=messages, + config=types.GenerateContentConfig( + thinking_config=types.ThinkingConfig( + include_thoughts=True, + thinking_budget=5000, + ), + max_output_tokens=800, + ), + ) + + # Validate response structure + assert response is not None, "Response should not be None" + assert hasattr(response, "candidates"), "Response should have candidates" + assert len(response.candidates) > 0, "Should have at least one candidate" + + candidate = response.candidates[0] + assert hasattr(candidate, "content"), "Candidate should have content" + assert hasattr(candidate.content, "parts"), "Content should have parts" + + # Check for thoughts in usage metadata + has_thoughts = False + thoughts_token_count = 0 + + if hasattr(response, "usage_metadata"): + usage = response.usage_metadata + if hasattr(usage, "thoughts_token_count"): + thoughts_token_count = usage.thoughts_token_count + has_thoughts = thoughts_token_count > 0 + print(f"Found thoughts with {thoughts_token_count} tokens") + + # Should have thinking/thoughts tokens + assert has_thoughts, ( + f"Response should contain thoughts/reasoning tokens. " + f"Usage metadata: {response.usage_metadata if hasattr(response, 'usage_metadata') else 'None'}" + ) + assert thoughts_token_count > 0, "Thoughts token count should be greater than 0" + + # Validate that we have a response (even if thoughts aren't directly visible in parts) + # In Gemini, thoughts are counted but may not be directly exposed in the response + regular_text = "" + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + regular_text += part.text + + # Should have regular response text + assert len(regular_text) > 0, "Should have regular response text" + + print(f"✓ Thoughts used {thoughts_token_count} tokens") + print(f"✓ Response content: {regular_text[:200]}...") + + # Validate the response makes sense for the problem + response_lower = regular_text.lower() + reasoning_keywords = [ + "egg", "milk", "chicken", "cow", "profit", + "cost", "revenue", "week", "calculate", "total" + ] + + keyword_matches = sum( + 1 for keyword in reasoning_keywords if keyword in response_lower + ) + assert keyword_matches >= 3, ( + f"Response should address the farmer problem. " + f"Found {keyword_matches} keywords. Content: {regular_text[:200]}..." + ) + + @skip_if_no_api_key("google") + def test_27_extended_thinking_streaming(self, google_client, test_config): + """Test Case 27: Extended thinking/reasoning (streaming)""" + from google.genai import types + + # Convert to Google GenAI message format + messages = GEMINI_REASONING_STREAMING_PROMPT[0]["content"] + + # Stream with thinking enabled + stream = google_client.models.generate_content_stream( + model=get_model("google", "chat"), + contents=messages, + config=types.GenerateContentConfig( + thinking_config=types.ThinkingConfig( + include_thoughts=True, + thinking_budget=5000, + ), + max_output_tokens=800, + ), + ) + + # Collect streaming content + text_parts = [] + chunk_count = 0 + final_usage = None + + for chunk in stream: + chunk_count += 1 + + # Collect text content + if hasattr(chunk, "candidates") and chunk.candidates is not None and len(chunk.candidates) > 0: + candidate = chunk.candidates[0] + if hasattr(candidate, "content") and hasattr(candidate.content, "parts") and candidate.content.parts: + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + text_parts.append(part.text) + + # Capture final usage metadata + if hasattr(chunk, "usage_metadata"): + final_usage = chunk.usage_metadata + + # Safety check + if chunk_count > 500: + break + + # Combine collected content + complete_text = "".join(text_parts) + + # Validate results + assert chunk_count > 0, "Should receive at least one chunk" + assert final_usage is not None, "Should have usage metadata" + + # Check for thoughts in usage metadata + has_thoughts = False + thoughts_token_count = 0 + + if hasattr(final_usage, "thoughts_token_count"): + thoughts_token_count = final_usage.thoughts_token_count + has_thoughts = thoughts_token_count > 0 + + assert has_thoughts, ( + f"Should detect thinking in streaming. " + f"Usage metadata: {final_usage}" + ) + assert thoughts_token_count > 0, ( + f"Should have substantial thinking tokens, got {thoughts_token_count}. " + f"Text parts: {len(text_parts)}" + ) + + # Should have regular response text too + assert len(complete_text) > 0, "Should have regular response text" + + # Validate thinking content + text_lower = complete_text.lower() + library_keywords = [ + "book", "library", "lent", "return", "donation", + "total", "available", "inventory", "calculate", "percent" + ] + + keyword_matches = sum( + 1 for keyword in library_keywords if keyword in text_lower + ) + assert keyword_matches >= 3, ( + f"Response should reason about the library problem. " + f"Found {keyword_matches} keywords. Content: {complete_text[:200]}..." + ) + + print(f"✓ Streamed with thinking ({thoughts_token_count} thought tokens)") + print(f"✓ Streamed response ({len(text_parts)} chunks): {complete_text[:150]}...") + # Additional helper functions specific to Google GenAI def extract_google_function_calls(response: Any) -> List[Dict[str, Any]]: diff --git a/tests/integrations/tests/utils/common.py b/tests/integrations/tests/utils/common.py index 6459369e7..85adcc360 100644 --- a/tests/integrations/tests/utils/common.py +++ b/tests/integrations/tests/utils/common.py @@ -222,6 +222,29 @@ class Config: } ] +# Gemini Reasoning Test Prompts +GEMINI_REASONING_PROMPT = [ + { + "role": "user", + "content": ( + "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. " + "If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, " + "what is the farmer's weekly profit? Please show your step-by-step reasoning." + ), + } +] + +GEMINI_REASONING_STREAMING_PROMPT = [ + { + "role": "user", + "content": ( + "A library has 1200 books. In January, they lent out 40% of their books. In February, they got 150 books returned and lent out 200 new books. " + "In March, they received 80 new books as donations and lent out 25% of their current inventory. " + "How many books does the library have available at the end of March? Think through this step by step." + ), + } +] + IMAGE_URL_MESSAGES = [ { "role": "user", diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 3fe31b119..c8a4a6646 100644 --- a/transports/bifrost-http/integrations/genai.go +++ b/transports/bifrost-http/integrations/genai.go @@ -48,17 +48,17 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { }, nil } else { return &schemas.BifrostRequest{ - ChatRequest: geminiReq.ToBifrostChatRequest(), + ResponsesRequest: geminiReq.ToBifrostResponsesRequest(), }, nil } } return nil, errors.New("invalid request type") }, - EmbeddingResponseConverter: func(ctx *context.Context, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + EmbeddingResponseConverter: func(ctx *context.Context, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { return gemini.ToGeminiEmbeddingResponse(resp), nil }, - ChatResponseConverter: func(ctx *context.Context, resp *schemas.BifrostChatResponse) (interface{}, error) { - return gemini.ToGeminiChatResponse(resp), nil + ResponsesResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) { + return gemini.ToGeminiResponsesResponse(resp), nil }, SpeechResponseConverter: func(ctx *context.Context, resp *schemas.BifrostSpeechResponse) (interface{}, error) { return gemini.ToGeminiSpeechResponse(resp), nil @@ -70,8 +70,13 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { return gemini.ToGeminiError(err) }, StreamConfig: &StreamConfig{ - ChatStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostChatResponse) (string, interface{}, error) { - return "", gemini.ToGeminiChatResponse(resp), nil + ResponsesStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { + geminiResponse := gemini.ToGeminiResponsesStreamResponse(resp) + // Skip lifecycle events with no Gemini equivalent + if geminiResponse == nil { + return "", nil, nil + } + return "", geminiResponse, nil }, ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err)