diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 848101e5f..53ca03c58 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -139,6 +139,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", }, }, }, diff --git a/core/internal/testutil/chat_completion_stream.go b/core/internal/testutil/chat_completion_stream.go index 2f9385c42..e4c7ed973 100644 --- a/core/internal/testutil/chat_completion_stream.go +++ b/core/internal/testutil/chat_completion_stream.go @@ -355,4 +355,355 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont t.Logf("✅ Streaming with tools test completed successfully") }) } + + // Test chat completion streaming with reasoning if supported + if testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != "" { + t.Run("ChatCompletionStreamWithReasoning", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + problemPrompt := "Solve this step by step: If a train leaves station A at 2 PM traveling at 60 mph, and another train leaves station B at 3 PM traveling at 80 mph toward station A, and the stations are 420 miles apart, when will they meet?" + + messages := []schemas.ChatMessage{ + CreateBasicChatMessage(problemPrompt), + } + + request := &schemas.BifrostChatRequest{ + Provider: testConfig.Provider, + Model: testConfig.ReasoningModel, + Input: messages, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1800), + Reasoning: &schemas.ChatReasoning{ + Effort: bifrost.Ptr("high"), + MaxTokens: bifrost.Ptr(1500), + }, + }, + Fallbacks: testConfig.Fallbacks, + } + + // Use retry framework for stream requests with reasoning + retryConfig := StreamingRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "ChatCompletionStreamWithReasoning", + ExpectedBehavior: map[string]interface{}{ + "should_stream_reasoning": true, + "should_have_reasoning_events": true, + "problem_type": "mathematical", + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ReasoningModel, + "reasoning": true, + }, + } + + // Use proper streaming retry wrapper for the stream request + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return client.ChatCompletionStreamRequest(ctx, request) + }) + + RequireNoError(t, err, "Chat completion stream with reasoning failed") + if responseChannel == nil { + t.Fatal("Response channel should not be nil") + } + + var reasoningDetected bool + var reasoningDetailsDetected bool + var reasoningTokensDetected bool + var responseCount int + + streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second) + defer cancel() + + t.Logf("🧠 Testing chat completion streaming with reasoning...") + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto reasoningStreamComplete + } + + if response == nil { + t.Fatal("Streaming response should not be nil") + } + responseCount++ + + if response.BifrostChatResponse != nil { + chatResp := response.BifrostChatResponse + + // Check for reasoning in choices + if len(chatResp.Choices) > 0 { + for _, choice := range chatResp.Choices { + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + delta := choice.ChatStreamResponseChoice.Delta + + // Check for reasoning content in delta + if delta.Reasoning != nil && *delta.Reasoning != "" { + reasoningDetected = true + t.Logf("🧠 Reasoning content detected: %q", *delta.Reasoning) + } + + // Check for reasoning details in delta + if len(delta.ReasoningDetails) > 0 { + reasoningDetailsDetected = true + t.Logf("🧠 Reasoning details detected: %d entries", len(delta.ReasoningDetails)) + + for _, detail := range delta.ReasoningDetails { + t.Logf(" - Type: %s, Index: %d", detail.Type, detail.Index) + switch detail.Type { + case schemas.BifrostReasoningDetailsTypeText: + if detail.Text != nil && *detail.Text != "" { + maxLen := 100 + text := *detail.Text + if len(text) < maxLen { + maxLen = len(text) + } + t.Logf(" Text preview: %q", text[:maxLen]) + } + case schemas.BifrostReasoningDetailsTypeSummary: + if detail.Summary != nil { + t.Logf(" Summary length: %d", len(*detail.Summary)) + } + case schemas.BifrostReasoningDetailsTypeEncrypted: + if detail.Data != nil { + t.Logf(" Encrypted data length: %d", len(*detail.Data)) + } + } + } + } + } + } + } + + // Check for reasoning tokens in usage (usually in final chunk) + if chatResp.Usage != nil && chatResp.Usage.CompletionTokensDetails != nil { + if chatResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 { + reasoningTokensDetected = true + t.Logf("đŸ”ĸ Reasoning tokens used: %d", chatResp.Usage.CompletionTokensDetails.ReasoningTokens) + } + } + } + + if responseCount > 150 { + goto reasoningStreamComplete + } + + case <-streamCtx.Done(): + t.Fatal("Timeout waiting for chat completion streaming response with reasoning") + } + } + + reasoningStreamComplete: + if responseCount == 0 { + t.Fatal("Should receive at least one streaming response") + } + + // At least one of these should be detected for reasoning + if !reasoningDetected && !reasoningDetailsDetected && !reasoningTokensDetected { + t.Logf("âš ī¸ Warning: No explicit reasoning indicators found in streaming response") + } else { + t.Logf("✅ Reasoning indicators detected:") + if reasoningDetected { + t.Logf(" - Reasoning content found") + } + if reasoningDetailsDetected { + t.Logf(" - Reasoning details found") + } + if reasoningTokensDetected { + t.Logf(" - Reasoning tokens reported") + } + } + + t.Logf("✅ Chat completion streaming with reasoning test completed successfully") + }) + + // Additional test with full validation and retry support + t.Run("ChatCompletionStreamWithReasoningValidated", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + if testConfig.Provider == schemas.OpenAI || testConfig.Provider == schemas.Groq { + // OpenAI and Groq because reasoning for them in stream is extremely flaky + t.Skip("Skipping ChatCompletionStreamWithReasoningValidated test for OpenAI and Groq") + return + } + + problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit?" + if testConfig.Provider == schemas.Cerebras { + problemPrompt = "Hello how are you, can you search hackernews news regarding maxim ai for me? use your tools for this" + } + + messages := []schemas.ChatMessage{ + CreateBasicChatMessage(problemPrompt), + } + + request := &schemas.BifrostChatRequest{ + Provider: testConfig.Provider, + Model: testConfig.ReasoningModel, + Input: messages, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1800), + Reasoning: &schemas.ChatReasoning{ + Effort: bifrost.Ptr("high"), + MaxTokens: bifrost.Ptr(1500), + }, + }, + Fallbacks: testConfig.Fallbacks, + } + + // Use retry framework for stream requests with reasoning and validation + retryConfig := StreamingRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "ChatCompletionStreamWithReasoningValidated", + ExpectedBehavior: map[string]interface{}{ + "should_stream_reasoning": true, + "should_have_reasoning_indicators": true, + "problem_type": "mathematical", + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ReasoningModel, + "reasoning": true, + "validated": true, + }, + } + + // Use validation retry wrapper that includes stream reading and validation + validationResult := WithChatStreamValidationRetry( + t, + retryConfig, + retryContext, + func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return client.ChatCompletionStreamRequest(ctx, request) + }, + func(responseChannel chan *schemas.BifrostStream) ChatStreamValidationResult { + var reasoningDetected bool + var reasoningDetailsDetected bool + var reasoningTokensDetected bool + var responseCount int + var streamErrors []string + var fullContent strings.Builder + + streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second) + defer cancel() + + t.Logf("🧠 Testing validated chat completion streaming with reasoning...") + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto validatedReasoningStreamComplete + } + + if response == nil { + streamErrors = append(streamErrors, "❌ Streaming response should not be nil") + continue + } + responseCount++ + + if response.BifrostChatResponse != nil { + chatResp := response.BifrostChatResponse + + // Check for reasoning in choices + if len(chatResp.Choices) > 0 { + for _, choice := range chatResp.Choices { + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + delta := choice.ChatStreamResponseChoice.Delta + + // Accumulate content + if delta.Content != nil { + fullContent.WriteString(*delta.Content) + t.Logf("📝 Content chunk received (length: %d, total so far: %d)", len(*delta.Content), fullContent.Len()) + } + + // Check for reasoning content in delta + if delta.Reasoning != nil && *delta.Reasoning != "" { + reasoningDetected = true + t.Logf("🧠 Reasoning content detected (length: %d)", len(*delta.Reasoning)) + } + + // Check for reasoning details in delta + if len(delta.ReasoningDetails) > 0 { + reasoningDetailsDetected = true + t.Logf("🧠 Reasoning details detected: %d entries", len(delta.ReasoningDetails)) + } + } + } + } + + // Check for reasoning tokens in usage + if chatResp.Usage != nil && chatResp.Usage.CompletionTokensDetails != nil { + if chatResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 { + reasoningTokensDetected = true + t.Logf("đŸ”ĸ Reasoning tokens: %d", chatResp.Usage.CompletionTokensDetails.ReasoningTokens) + } + } + } + + if responseCount > 150 { + goto validatedReasoningStreamComplete + } + + case <-streamCtx.Done(): + streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response with reasoning") + goto validatedReasoningStreamComplete + } + } + + validatedReasoningStreamComplete: + var errors []string + if responseCount == 0 { + errors = append(errors, "❌ Should receive at least one streaming response") + } + + // Check if at least one reasoning indicator is present + hasAnyReasoningIndicator := reasoningDetected || reasoningDetailsDetected || reasoningTokensDetected + if !hasAnyReasoningIndicator { + errors = append(errors, fmt.Sprintf("❌ No reasoning indicators found in streaming response (received %d chunks)", responseCount)) + } + + // Check content - for reasoning models, content may come after reasoning or may not be present + // If reasoning is detected, we consider it a valid response even without content + content := strings.TrimSpace(fullContent.String()) + if content == "" && !hasAnyReasoningIndicator { + // Only require content if no reasoning indicators were found + errors = append(errors, "❌ No content received in streaming response and no reasoning indicators found") + } else if content == "" && hasAnyReasoningIndicator { + // Log a warning but don't fail if reasoning is present + t.Logf("âš ī¸ Warning: Reasoning detected but no content chunks received (this may be expected for some reasoning models)") + } + + if len(streamErrors) > 0 { + errors = append(errors, streamErrors...) + } + + return ChatStreamValidationResult{ + Passed: len(errors) == 0, + Errors: errors, + ReceivedData: responseCount > 0 && (content != "" || hasAnyReasoningIndicator), + StreamErrors: streamErrors, + ToolCallDetected: false, // Not testing tool calls here + ResponseCount: responseCount, + } + }, + ) + + // Check validation result + if !validationResult.Passed { + allErrors := append(validationResult.Errors, validationResult.StreamErrors...) + t.Fatalf("❌ Chat completion stream with reasoning validation failed after retries: %s", strings.Join(allErrors, "; ")) + } + + if validationResult.ResponseCount == 0 { + t.Fatalf("❌ Should receive at least one streaming response") + } + + t.Logf("✅ Validated chat completion streaming with reasoning test completed successfully") + }) + } } diff --git a/core/internal/testutil/reasoning.go b/core/internal/testutil/reasoning.go index 5adff016d..17b7c4583 100644 --- a/core/internal/testutil/reasoning.go +++ b/core/internal/testutil/reasoning.go @@ -9,8 +9,8 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -// RunReasoningTest executes the reasoning test scenario to test thinking capabilities via Responses API only -func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { +// RunResponsesReasoningTest executes the reasoning test scenario to test thinking capabilities via Responses API only +func RunResponsesReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { if !testConfig.Scenarios.Reasoning { t.Logf("â­ī¸ Reasoning not supported for provider %s", testConfig.Provider) return @@ -22,7 +22,7 @@ func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context return } - t.Run("Reasoning", func(t *testing.T) { + t.Run("ResponsesReasoning", func(t *testing.T) { if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { t.Parallel() } @@ -40,7 +40,7 @@ func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context Model: testConfig.ReasoningModel, Input: responsesMessages, Params: &schemas.ResponsesParameters{ - MaxOutputTokens: bifrost.Ptr(800), + MaxOutputTokens: bifrost.Ptr(1800), // Configure reasoning-specific parameters Reasoning: &schemas.ResponsesParametersReasoning{ Effort: bifrost.Ptr("high"), // High effort for complex reasoning @@ -198,6 +198,218 @@ func validateResponsesAPIReasoning(t *testing.T, response *schemas.BifrostRespon return detected } +// RunChatCompletionReasoningTest executes the reasoning test scenario to test thinking capabilities via Chat Completions API +func RunChatCompletionReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.Reasoning { + t.Logf("â­ī¸ Reasoning not supported for provider %s", testConfig.Provider) + return + } + + // Skip if no reasoning model is configured + if testConfig.ReasoningModel == "" { + t.Logf("â­ī¸ No reasoning model configured for provider %s", testConfig.Provider) + return + } + + t.Run("ChatCompletionReasoning", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + if testConfig.Provider == schemas.OpenAI { + // OpenAI because reasoning for them in chat completions is extremely flaky + t.Skip("Skipping ChatCompletionReasoning test for OpenAI") + return + } + + // Create a complex problem that requires step-by-step reasoning + problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit? Please show your step-by-step reasoning." + + chatMessages := []schemas.ChatMessage{ + CreateBasicChatMessage(problemPrompt), + } + + // Execute Chat Completions API test with retries + chatReq := &schemas.BifrostChatRequest{ + Provider: testConfig.Provider, + Model: testConfig.ReasoningModel, + Input: chatMessages, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1800), + // Configure reasoning-specific parameters + Reasoning: &schemas.ChatReasoning{ + Effort: bifrost.Ptr("high"), // High effort for complex reasoning + MaxTokens: bifrost.Ptr(1500), // Maximum tokens for reasoning output + }, + }, + Fallbacks: testConfig.Fallbacks, + } + + // Use retry framework with enhanced validation for reasoning + retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig) + retryContext := TestRetryContext{ + ScenarioName: "Reasoning", + ExpectedBehavior: map[string]interface{}{ + "should_show_reasoning": true, + "mathematical_problem": true, + "step_by_step": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ReasoningModel, + "problem_type": "mathematical", + "complexity": "high", + "expects_reasoning": true, + }, + } + chatRetryConfig := ChatRetryConfig{ + MaxAttempts: retryConfig.MaxAttempts, + BaseDelay: retryConfig.BaseDelay, + MaxDelay: retryConfig.MaxDelay, + Conditions: []ChatRetryCondition{}, // Add specific chat retry conditions as needed + OnRetry: retryConfig.OnRetry, + OnFinalFail: retryConfig.OnFinalFail, + } + + // Enhanced validation for reasoning scenarios + expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{ + "requires_reasoning": true, + }) + expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) + + response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return client.ChatCompletionRequest(ctx, chatReq) + }) + + if chatError != nil { + t.Fatalf("❌ Reasoning test failed after retries: %v", GetErrorMessage(chatError)) + } + + // Log the response content + chatContent := GetChatContent(response) + if chatContent == "" { + t.Logf("✅ Chat Completions API reasoning result: ") + } else { + maxLen := 300 + if len(chatContent) < maxLen { + maxLen = len(chatContent) + } + t.Logf("✅ Chat Completions API reasoning result: %s", chatContent[:maxLen]) + } + + // Additional reasoning-specific validation (complementary to the main validation) + reasoningDetected := validateChatCompletionReasoning(t, response) + if !reasoningDetected { + t.Logf("âš ī¸ No explicit reasoning indicators found in response structure - may still contain valid reasoning in content") + } else { + t.Logf("🧠 Reasoning structure detected in response") + } + + t.Logf("🎉 Chat Completions API passed Reasoning test!") + }) +} + +// validateChatCompletionReasoning performs additional validation specific to Chat Completions API reasoning features +// Returns true if reasoning indicators are found +func validateChatCompletionReasoning(t *testing.T, response *schemas.BifrostChatResponse) bool { + if response == nil || len(response.Choices) == 0 { + return false + } + + reasoningFound := false + reasoningDetailsFound := false + reasoningTokensFound := false + + // Check each choice for reasoning indicators + for _, choice := range response.Choices { + // Check for reasoning details in ChatNonStreamResponseChoice + if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil { + message := choice.ChatNonStreamResponseChoice.Message + + if message == nil { + continue + } + + // Check for reasoning content in message (for backward compatibility) + if message.ChatAssistantMessage != nil && message.ChatAssistantMessage.Reasoning != nil && *message.ChatAssistantMessage.Reasoning != "" { + reasoningFound = true + t.Logf("🧠 Found reasoning content in message (length: %d)", len(*message.ChatAssistantMessage.Reasoning)) + + // Log first 200 chars for debugging + reasoningText := *message.ChatAssistantMessage.Reasoning + maxLen := 200 + if len(reasoningText) < maxLen { + maxLen = len(reasoningText) + } + t.Logf("📋 First reasoning content: %s", reasoningText[:maxLen]) + } + + // Check for reasoning details array + if message.ChatAssistantMessage != nil && len(message.ChatAssistantMessage.ReasoningDetails) > 0 { + reasoningDetailsFound = true + t.Logf("📝 Found %d reasoning details entries", len(message.ChatAssistantMessage.ReasoningDetails)) + + // Log details about each reasoning entry + for i, detail := range message.ChatAssistantMessage.ReasoningDetails { + t.Logf(" - Entry %d: Type=%s, Index=%d", i, detail.Type, detail.Index) + + switch detail.Type { + case schemas.BifrostReasoningDetailsTypeSummary: + if detail.Summary != nil { + t.Logf(" Summary length: %d", len(*detail.Summary)) + } + case schemas.BifrostReasoningDetailsTypeText: + if detail.Text != nil { + textLen := len(*detail.Text) + t.Logf(" Text length: %d", textLen) + if textLen > 0 { + maxLen := 150 + if textLen < maxLen { + maxLen = textLen + } + t.Logf(" Text preview: %s", (*detail.Text)[:maxLen]) + } + } + case schemas.BifrostReasoningDetailsTypeEncrypted: + if detail.Data != nil { + t.Logf(" Encrypted data length: %d", len(*detail.Data)) + } + if detail.Signature != nil { + t.Logf(" Signature present: %d bytes", len(*detail.Signature)) + } + } + } + } + } + } + + // Check if reasoning tokens were used + if response.Usage != nil && response.Usage.CompletionTokensDetails != nil && + response.Usage.CompletionTokensDetails.ReasoningTokens > 0 { + reasoningTokensFound = true + t.Logf("đŸ”ĸ Reasoning tokens used: %d", response.Usage.CompletionTokensDetails.ReasoningTokens) + } + + // Log findings + detected := reasoningFound || reasoningDetailsFound || reasoningTokensFound + if detected { + t.Logf("✅ Chat Completions API reasoning indicators detected") + if reasoningFound { + t.Logf(" - Reasoning content found in message") + } + if reasoningDetailsFound { + t.Logf(" - Reasoning details array found") + } + if reasoningTokensFound { + t.Logf(" - Reasoning tokens usage reported") + } + } else { + t.Logf("â„šī¸ No explicit reasoning indicators found (may be provider-specific)") + } + + return detected +} + // min returns the smaller of two integers func min(a, b int) int { if a < b { diff --git a/core/internal/testutil/responses_stream.go b/core/internal/testutil/responses_stream.go index 9260ad667..ae3644b0a 100644 --- a/core/internal/testutil/responses_stream.go +++ b/core/internal/testutil/responses_stream.go @@ -437,7 +437,7 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: testConfig.ReasoningModel, Input: messages, Params: &schemas.ResponsesParameters{ - MaxOutputTokens: bifrost.Ptr(400), + MaxOutputTokens: bifrost.Ptr(1800), Reasoning: &schemas.ResponsesParametersReasoning{ Effort: bifrost.Ptr("high"), // Summary: bifrost.Ptr("detailed"), diff --git a/core/internal/testutil/tests.go b/core/internal/testutil/tests.go index d9db0cae2..e6ab78fe2 100644 --- a/core/internal/testutil/tests.go +++ b/core/internal/testutil/tests.go @@ -46,7 +46,8 @@ func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context RunTranscriptionStreamTest, RunTranscriptionStreamAdvancedTest, RunEmbeddingTest, - RunReasoningTest, + RunChatCompletionReasoningTest, + RunResponsesReasoningTest, RunListModelsTest, RunListModelsPaginationTest, RunPromptCachingTest, @@ -85,7 +86,8 @@ func printTestSummary(t *testing.T, testConfig ComprehensiveTestConfig) { {"Transcription", testConfig.Scenarios.Transcription}, {"TranscriptionStream", testConfig.Scenarios.TranscriptionStream}, {"Embedding", testConfig.Scenarios.Embedding && testConfig.EmbeddingModel != ""}, - {"Reasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""}, + {"ChatCompletionReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""}, + {"ResponsesReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""}, {"ListModels", testConfig.Scenarios.ListModels}, {"PromptCaching", testConfig.Scenarios.SimpleChat && testConfig.PromptCachingModel != ""}, } diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 1dd1ac8cb..aa1dfa1fd 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -674,18 +674,13 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke return nil, err } - // Convert to Anthropic format using the centralized converter - jsonData, err := providerUtils.CheckContextAndGetRequestBody( - ctx, - request, - func() (any, error) { return ToAnthropicResponsesRequest(request) }, - provider.GetProviderKey()) + jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false) if err != nil { return nil, err } // Use struct directly for JSON marshaling - responseBody, latency, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value) + responseBody, latency, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value) if err != nil { return nil, err } @@ -694,7 +689,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke response := AcquireAnthropicMessageResponse() defer ReleaseAnthropicMessageResponse(response) - rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) if bifrostErr != nil { return nil, bifrostErr } @@ -728,20 +723,9 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook } // Convert to Anthropic format using the centralized converter - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( - ctx, - request, - func() (any, error) { - reqBody, err := ToAnthropicResponsesRequest(request) - if err != nil { - return nil, err - } - reqBody.Stream = schemas.Ptr(true) - return reqBody, nil - }, - provider.GetProviderKey()) - if bifrostErr != nil { - return nil, bifrostErr + jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), true) + if err != nil { + return nil, err } // Prepare Anthropic headers diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index c2e0929d3..6567fbfc7 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -98,20 +99,26 @@ func ToAnthropicChatRequest(bifrostReq *schemas.BifrostChatRequest) (*AnthropicM // Convert reasoning if bifrostReq.Params.Reasoning != nil { - if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort == "none" { + if bifrostReq.Params.Reasoning.MaxTokens != nil { + if *bifrostReq.Params.Reasoning.MaxTokens < MinimumReasoningMaxTokens { + return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", MinimumReasoningMaxTokens) + } anthropicReq.Thinking = &AnthropicThinking{ - Type: "disabled", + Type: "enabled", + BudgetTokens: bifrostReq.Params.Reasoning.MaxTokens, + } + } else if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort != "none" { + budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, anthropicReq.MaxTokens) + if err != nil { + return nil, err + } + anthropicReq.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: schemas.Ptr(budgetTokens), } } else { - if bifrostReq.Params.Reasoning.MaxTokens == nil { - return nil, fmt.Errorf("reasoning.max_tokens is required for reasoning") - } else if *bifrostReq.Params.Reasoning.MaxTokens < MinimumReasoningMaxTokens { - return nil, fmt.Errorf("reasoning.max_tokens must be greater than or equal to %d", MinimumReasoningMaxTokens) - } else { - anthropicReq.Thinking = &AnthropicThinking{ - Type: "enabled", - BudgetTokens: bifrostReq.Params.Reasoning.MaxTokens, - } + anthropicReq.Thinking = &AnthropicThinking{ + Type: "disabled", } } } @@ -601,7 +608,7 @@ func (chunk *AnthropicStreamEvent) ToBifrostChatCompletionStream() (*schemas.Bif case AnthropicStreamDeltaTypeInputJSON: // Handle tool use streaming - accumulate partial JSON - if chunk.Delta.PartialJSON != nil && *chunk.Delta.PartialJSON != "" { + if chunk.Delta.PartialJSON != nil { // Create streaming response for tool input delta streamResponse := &schemas.BifrostChatResponse{ Object: "chat.completion.chunk", @@ -630,6 +637,7 @@ func (chunk *AnthropicStreamEvent) ToBifrostChatCompletionStream() (*schemas.Bif case AnthropicStreamDeltaTypeThinking: // Handle thinking content streaming if chunk.Delta.Thinking != nil && *chunk.Delta.Thinking != "" { + thinkingText := *chunk.Delta.Thinking // Create streaming response for thinking delta streamResponse := &schemas.BifrostChatResponse{ Object: "chat.completion.chunk", @@ -638,12 +646,12 @@ func (chunk *AnthropicStreamEvent) ToBifrostChatCompletionStream() (*schemas.Bif Index: 0, ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ Delta: &schemas.ChatStreamResponseChoiceDelta{ - Reasoning: chunk.Delta.Thinking, + Reasoning: schemas.Ptr(thinkingText), ReasoningDetails: []schemas.ChatReasoningDetails{ { Index: 0, Type: schemas.BifrostReasoningDetailsTypeText, - Text: chunk.Delta.Thinking, + Text: schemas.Ptr(thinkingText), }, }, }, diff --git a/core/providers/anthropic/errors.go b/core/providers/anthropic/errors.go index 58ea8fcc9..bb5180ac1 100644 --- a/core/providers/anthropic/errors.go +++ b/core/providers/anthropic/errors.go @@ -1,6 +1,9 @@ package anthropic import ( + "encoding/json" + "fmt" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" @@ -36,6 +39,24 @@ func ToAnthropicChatCompletionError(bifrostErr *schemas.BifrostError) *Anthropic } } +// ToAnthropicResponsesStreamError converts a BifrostError to Anthropic responses streaming error in SSE format +func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string { + if bifrostErr == nil { + return "" + } + + anthropicErr := ToAnthropicChatCompletionError(bifrostErr) + + // Marshal to JSON + jsonData, err := json.Marshal(anthropicErr) + if err != nil { + return "" + } + + // Format as Anthropic SSE error event + return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) +} + func parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp AnthropicError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index e92bdd5e3..ec6516819 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -9,7 +9,10 @@ import ( "sync" "time" + "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + + providerUtils "github.com/maximhq/bifrost/core/providers/utils" ) // AnthropicResponsesStreamState tracks state during streaming conversion for responses API @@ -27,9 +30,11 @@ type AnthropicResponsesStreamState struct { ItemIDs map[int]string // Maps output_index to item ID for stable IDs ReasoningSignatures map[int]string // Maps output_index to reasoning signature TextContentIndices map[int]bool // Tracks which content indices are text blocks + ReasoningContentIndices map[int]bool // Tracks which content indices are reasoning blocks CurrentOutputIndex int // Current output index counter MessageID *string // Message ID from message_start Model *string // Model name from message_start + StopReason *string // Stop reason for the message CreatedAt int // Timestamp for created_at consistency HasEmittedCreated bool // Whether we've emitted response.created HasEmittedInProgress bool // Whether we've emitted response.in_progress @@ -45,6 +50,7 @@ var anthropicResponsesStreamStatePool = sync.Pool{ ItemIDs: make(map[int]string), ReasoningSignatures: make(map[int]string), TextContentIndices: make(map[int]bool), + ReasoningContentIndices: make(map[int]bool), CurrentOutputIndex: 0, CreatedAt: int(time.Now().Unix()), HasEmittedCreated: false, @@ -88,12 +94,18 @@ func acquireAnthropicResponsesStreamState() *AnthropicResponsesStreamState { } else { clear(state.TextContentIndices) } + if state.ReasoningContentIndices == nil { + state.ReasoningContentIndices = make(map[int]bool) + } else { + clear(state.ReasoningContentIndices) + } // Reset other fields state.ChunkIndex = nil state.AccumulatedJSON = "" state.ComputerToolID = nil state.CurrentOutputIndex = 0 state.MessageID = nil + state.StopReason = nil state.Model = nil state.CreatedAt = int(time.Now().Unix()) state.HasEmittedCreated = false @@ -120,8 +132,10 @@ func (state *AnthropicResponsesStreamState) flush() { state.ItemIDs = make(map[int]string) state.ReasoningSignatures = make(map[int]string) state.TextContentIndices = make(map[int]bool) + state.ReasoningContentIndices = make(map[int]bool) state.CurrentOutputIndex = 0 state.MessageID = nil + state.StopReason = nil state.Model = nil state.CreatedAt = int(time.Now().Unix()) state.HasEmittedCreated = false @@ -148,165 +162,1274 @@ func (state *AnthropicResponsesStreamState) getOrCreateOutputIndex(contentIndex return outputIndex } -// ToBifrostResponsesRequest converts an Anthropic message request to Bifrost format -func (request *AnthropicMessageRequest) ToBifrostResponsesRequest() *schemas.BifrostResponsesRequest { - provider, model := schemas.ParseModelString(request.Model, schemas.Anthropic) - - bifrostReq := &schemas.BifrostResponsesRequest{ - Provider: provider, - Model: model, - Fallbacks: schemas.ParseFallbacks(request.Fallbacks), - } +// ToBifrostResponsesStream converts an Anthropic stream event to a Bifrost Responses Stream response +// It maintains state via the state for handling multi-chunk conversions like computer tools +// Returns a slice of responses to support cases where a single event produces multiple responses +func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, sequenceNumber int, state *AnthropicResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError, bool) { + switch chunk.Type { + case AnthropicStreamEventTypeMessageStart: + // Message start - emit response.created and response.in_progress (OpenAI-style lifecycle) + if chunk.Message != nil { + state.MessageID = &chunk.Message.ID + state.Model = &chunk.Message.Model + // Use the state's CreatedAt for consistency + if state.CreatedAt == 0 { + state.CreatedAt = int(time.Now().Unix()) + } - // Convert basic parameters - params := &schemas.ResponsesParameters{ - ExtraParams: make(map[string]interface{}), - } + var responses []*schemas.BifrostResponsesStreamResponse - if request.MaxTokens > 0 { - params.MaxOutputTokens = &request.MaxTokens - } - if request.Temperature != nil { - params.Temperature = request.Temperature - } - if request.TopP != nil { - params.TopP = request.TopP - } - if request.Metadata != nil && request.Metadata.UserID != nil { - params.User = request.Metadata.UserID - } - if request.TopK != nil { - params.ExtraParams["top_k"] = *request.TopK - } - if request.StopSequences != nil { - params.ExtraParams["stop"] = request.StopSequences - } - if request.Thinking != nil { - params.ExtraParams["thinking"] = request.Thinking - } - if request.OutputFormat != nil { - params.Text = convertAnthropicOutputFormatToResponsesTextConfig(request.OutputFormat) - } - if request.Thinking != nil { - if request.Thinking.Type == "enabled" { - params.Reasoning = &schemas.ResponsesParametersReasoning{ - Effort: schemas.Ptr("auto"), - MaxTokens: request.Thinking.BudgetTokens, + // Emit response.created + if !state.HasEmittedCreated { + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, + } + if state.Model != nil { + response.Model = *state.Model + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCreated, + SequenceNumber: sequenceNumber, + Response: response, + }) + state.HasEmittedCreated = true } - } else { - params.Reasoning = &schemas.ResponsesParametersReasoning{ - Effort: schemas.Ptr("none"), + + // Emit response.in_progress + if !state.HasEmittedInProgress { + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, // Use same timestamp + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeInProgress, + SequenceNumber: sequenceNumber + len(responses), + Response: response, + }) + state.HasEmittedInProgress = true } - } - } - // Add trucation parameter if computer tool is being used - if provider == schemas.OpenAI && request.Tools != nil { - for _, tool := range request.Tools { - if tool.Type != nil && *tool.Type == AnthropicToolTypeComputer20250124 { - params.Truncation = schemas.Ptr("auto") - break + if len(responses) > 0 { + return responses, nil, false } } - } - bifrostReq.Params = params + case AnthropicStreamEventTypeContentBlockStart: + // Content block start - emit output_item.added (OpenAI-style) + if chunk.ContentBlock != nil && chunk.Index != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) - // Convert messages directly to ChatMessage format - var bifrostMessages []schemas.ResponsesMessage + if chunk.ContentBlock.Type == AnthropicContentBlockTypeToolUse && + chunk.ContentBlock.Name != nil && + *chunk.ContentBlock.Name == string(AnthropicToolNameComputer) && + chunk.ContentBlock.ID != nil { - // Handle system message - convert Anthropic system field to first message with role "system" - if request.System != nil { - var systemText string - if request.System.ContentStr != nil { - systemText = *request.System.ContentStr - } else if request.System.ContentBlocks != nil { - // Combine text blocks from system content - var textParts []string - for _, block := range request.System.ContentBlocks { - if block.Text != nil { - textParts = append(textParts, *block.Text) + // Start accumulating computer tool + state.ComputerToolID = chunk.ContentBlock.ID + state.ChunkIndex = chunk.Index + state.AccumulatedJSON = "" + + // Emit output_item.added for computer_call + item := &schemas.ResponsesMessage{ + ID: chunk.ContentBlock.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeComputerCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: chunk.ContentBlock.ID, + }, } - } - systemText = strings.Join(textParts, "\n") - } - if systemText != "" { - systemMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), - Content: &schemas.ResponsesMessageContent{ - ContentStr: &systemText, - }, + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }}, nil, false } - bifrostMessages = append(bifrostMessages, systemMsg) - } - } - - // Convert regular messages - for _, msg := range request.Messages { - convertedMessages := convertAnthropicMessageToBifrostResponsesMessages(&msg) - bifrostMessages = append(bifrostMessages, convertedMessages...) - } - // Convert tools if present - if request.Tools != nil { - var bifrostTools []schemas.ResponsesTool - for _, tool := range request.Tools { - bifrostTool := convertAnthropicToolToBifrost(&tool) - if bifrostTool != nil { - bifrostTools = append(bifrostTools, *bifrostTool) - } - } - if len(bifrostTools) > 0 { - bifrostReq.Params.Tools = bifrostTools - } - } + switch chunk.ContentBlock.Type { + case AnthropicContentBlockTypeText: + // Text block - emit output_item.added with type "message" + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant - if request.MCPServers != nil { - var bifrostMCPTools []schemas.ResponsesTool - for _, mcpServer := range request.MCPServers { - bifrostMCPTool := convertAnthropicMCPServerToBifrostTool(&mcpServer) - if bifrostMCPTool != nil { - bifrostMCPTools = append(bifrostMCPTools, *bifrostMCPTool) - } - } - if len(bifrostMCPTools) > 0 { - bifrostReq.Params.Tools = append(bifrostReq.Params.Tools, bifrostMCPTools...) - } - } + // Generate stable ID for text item + var itemID string + if state.MessageID == nil { + itemID = fmt.Sprintf("item_%d", outputIndex) + } else { + itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) + } + state.ItemIDs[outputIndex] = itemID - // Convert tool choice if present - if request.ToolChoice != nil { - bifrostToolChoice := convertAnthropicToolChoiceToBifrost(request.ToolChoice) - if bifrostToolChoice != nil { - bifrostReq.Params.ToolChoice = bifrostToolChoice - } - } + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // Empty blocks slice for mutation support + }, + } - // Set the converted messages - if len(bifrostMessages) > 0 { - bifrostReq.Input = bifrostMessages - } + // Track that this content index is a text block + if chunk.Index != nil { + state.TextContentIndices[*chunk.Index] = true + } - return bifrostReq -} + var responses []*schemas.BifrostResponsesStreamResponse -// ToAnthropicResponsesRequest converts a BifrostRequest with Responses structure back to AnthropicMessageRequest -func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*AnthropicMessageRequest, error) { - if bifrostReq == nil { - return nil, fmt.Errorf("bifrost request is nil") - } + // Emit output_item.added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }) - anthropicReq := &AnthropicMessageRequest{ - Model: bifrostReq.Model, - MaxTokens: AnthropicDefaultMaxTokens, - } + // Emit content_part.added with empty output_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &emptyText, + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + ItemID: &itemID, + Part: part, + }) - // Convert basic parameters - if bifrostReq.Params != nil { - if bifrostReq.Params.MaxOutputTokens != nil { - anthropicReq.MaxTokens = *bifrostReq.Params.MaxOutputTokens + return responses, nil, false + case AnthropicContentBlockTypeToolUse: + // Function call starting - emit output_item.added with type "function_call" and status "in_progress" + statusInProgress := "in_progress" + itemID := "" + if chunk.ContentBlock.ID != nil { + itemID = *chunk.ContentBlock.ID + state.ItemIDs[outputIndex] = itemID + } + item := &schemas.ResponsesMessage{ + ID: chunk.ContentBlock.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: &statusInProgress, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: chunk.ContentBlock.ID, + Name: chunk.ContentBlock.Name, + Arguments: schemas.Ptr(""), // Arguments will be filled by deltas + }, + } + + // Initialize argument buffer for this tool call + state.ToolArgumentBuffers[outputIndex] = "" + + // Mark tool use blocks to prevent synthetic content_part.added events + // This prevents extra content_block_stop events for tools like web_search + if chunk.Index != nil { + state.TextContentIndices[*chunk.Index] = false + } + + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }}, nil, false + case AnthropicContentBlockTypeMCPToolUse: + // MCP tool call starting - emit output_item.added + itemID := "" + if chunk.ContentBlock.ID != nil { + itemID = *chunk.ContentBlock.ID + state.ItemIDs[outputIndex] = itemID + } + item := &schemas.ResponsesMessage{ + ID: chunk.ContentBlock.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Name: chunk.ContentBlock.Name, + Arguments: schemas.Ptr(""), // Arguments will be filled by deltas + }, + } + + // Set server name if present + if chunk.ContentBlock.ServerName != nil { + item.ResponsesToolMessage.ResponsesMCPToolCall = &schemas.ResponsesMCPToolCall{ + ServerLabel: *chunk.ContentBlock.ServerName, + } + } + + // Initialize argument buffer for this MCP call and mark as MCP + state.ToolArgumentBuffers[outputIndex] = "" + state.MCPCallOutputIndices[outputIndex] = true + + // Mark MCP tool use blocks to prevent synthetic content_part.added events + if chunk.Index != nil { + state.TextContentIndices[*chunk.Index] = false + } + + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }}, nil, false + case AnthropicContentBlockTypeThinking: + // Thinking/reasoning block - emit output_item.added with type "reasoning" + messageType := schemas.ResponsesMessageTypeReasoning + role := schemas.ResponsesInputMessageRoleAssistant + + // Generate stable ID for reasoning item + var itemID string + if state.MessageID == nil { + itemID = fmt.Sprintf("reasoning_%d", outputIndex) + } else { + itemID = fmt.Sprintf("msg_%s_reasoning_%d", *state.MessageID, outputIndex) + } + state.ItemIDs[outputIndex] = itemID + + // Initialize reasoning structure + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + }, + } + + // Track that this content index is a reasoning block + if chunk.Index != nil { + state.ReasoningContentIndices[*chunk.Index] = true + } + + var responses []*schemas.BifrostResponsesStreamResponse + + // Emit output_item.added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }) + + // Emit content_part.added with empty reasoning_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: &emptyText, + } + // Preserve signature in the content part if present + if chunk.ContentBlock.Signature != nil { + part.Signature = chunk.ContentBlock.Signature + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + ItemID: &itemID, + Part: part, + }) + + return responses, nil, false + default: + // Send down an empty response only when integration type is anthropic + if ctx.Value(schemas.BifrostContextKeyIntegrationType) == "anthropic" { + return []*schemas.BifrostResponsesStreamResponse{{ + Type: "", + SequenceNumber: sequenceNumber, + }}, nil, false + } + return nil, nil, false + } + } + + case AnthropicStreamEventTypeContentBlockDelta: + if chunk.Index != nil && chunk.Delta != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) + + // Handle different delta types + switch chunk.Delta.Type { + case AnthropicStreamDeltaTypeText: + if chunk.Delta.Text != nil && *chunk.Delta.Text != "" { + // Text content delta - emit output_text.delta with item ID + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Delta: chunk.Delta.Text, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } + + case AnthropicStreamDeltaTypeInputJSON: + // Function call arguments delta + if chunk.Delta.PartialJSON != nil { + // Check if we're accumulating a computer tool + if state.ComputerToolID != nil && + state.ChunkIndex != nil && + *state.ChunkIndex == *chunk.Index { + // Accumulate the JSON and don't emit anything + state.AccumulatedJSON += *chunk.Delta.PartialJSON + return nil, nil, false + } + + // Accumulate tool arguments in buffer + if _, exists := state.ToolArgumentBuffers[outputIndex]; !exists { + state.ToolArgumentBuffers[outputIndex] = "" + } + state.ToolArgumentBuffers[outputIndex] += *chunk.Delta.PartialJSON + + // Emit appropriate delta type based on whether this is an MCP call + var deltaType schemas.ResponsesStreamResponseType + if state.MCPCallOutputIndices[outputIndex] { + deltaType = schemas.ResponsesStreamResponseTypeMCPCallArgumentsDelta + } else { + deltaType = schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta + } + + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: deltaType, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Delta: chunk.Delta.PartialJSON, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } + + case AnthropicStreamDeltaTypeThinking: + // Reasoning/thinking content delta + if chunk.Delta.Thinking != nil && *chunk.Delta.Thinking != "" { + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Delta: chunk.Delta.Thinking, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } + + case AnthropicStreamDeltaTypeSignature: + // Handle signature verification for thinking content + // Store the signature in state for the reasoning item + if chunk.Delta.Signature != nil && *chunk.Delta.Signature != "" { + state.ReasoningSignatures[outputIndex] = *chunk.Delta.Signature + // Emit signature_delta event using the signature field + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Signature: chunk.Delta.Signature, // Use signature field instead of delta + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } + return nil, nil, false + } + } + + case AnthropicStreamEventTypeContentBlockStop: + // Content block is complete - emit output_item.done (OpenAI-style) + if chunk.Index != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) + + // Check if this is the end of a computer tool accumulation + if state.ComputerToolID != nil && + state.ChunkIndex != nil && + *state.ChunkIndex == *chunk.Index { + + // Parse accumulated JSON and convert to OpenAI format + var inputMap map[string]interface{} + var action *schemas.ResponsesComputerToolCallAction + + if state.AccumulatedJSON != "" { + if err := json.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { + action = convertAnthropicToResponsesComputerAction(inputMap) + } + } + + // Create computer_call item with action + statusCompleted := "completed" + item := &schemas.ResponsesMessage{ + ID: state.ComputerToolID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeComputerCall), + Status: &statusCompleted, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: state.ComputerToolID, + ResponsesComputerToolCall: &schemas.ResponsesComputerToolCall{ + PendingSafetyChecks: []schemas.ResponsesComputerToolCallPendingSafetyCheck{}, + }, + }, + } + + // Add action if we successfully parsed it + if action != nil { + item.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: action, + } + } + + // Clear computer tool state + state.ComputerToolID = nil + state.ChunkIndex = nil + state.AccumulatedJSON = "" + + // Return output_item.done + return []*schemas.BifrostResponsesStreamResponse{ + { + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }, + }, nil, false + } + + // Check if this is a text block - emit output_text.done and content_part.done + var responses []*schemas.BifrostResponsesStreamResponse + itemID := state.ItemIDs[outputIndex] + + // Check if this content index is a text block + if chunk.Index != nil { + if state.TextContentIndices[*chunk.Index] { + // Emit output_text.done (without accumulated text, just the event) + emptyText := "" + textDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Text: &emptyText, + } + if itemID != "" { + textDoneResponse.ItemID = &itemID + } + responses = append(responses, textDoneResponse) + + // Emit content_part.done + partDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + } + if itemID != "" { + partDoneResponse.ItemID = &itemID + } + responses = append(responses, partDoneResponse) + + // Clear the text content index tracking + delete(state.TextContentIndices, *chunk.Index) + } + + // Check if this content index is a reasoning block + if state.ReasoningContentIndices[*chunk.Index] { + // Emit reasoning_summary_text.done (reasoning equivalent of output_text.done) + emptyText := "" + reasoningDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Text: &emptyText, + } + if itemID != "" { + reasoningDoneResponse.ItemID = &itemID + } + responses = append(responses, reasoningDoneResponse) + + // Emit content_part.done for reasoning + partDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + } + if itemID != "" { + partDoneResponse.ItemID = &itemID + } + responses = append(responses, partDoneResponse) + + // Clear the reasoning content index tracking + delete(state.ReasoningContentIndices, *chunk.Index) + } + } + + // Check if this is a tool call (function_call or MCP call) + // If we have accumulated arguments, emit appropriate arguments.done first + if accumulatedArgs, hasArgs := state.ToolArgumentBuffers[outputIndex]; hasArgs && accumulatedArgs != "" { + // Emit appropriate arguments.done based on whether this is an MCP call + var doneType schemas.ResponsesStreamResponseType + if state.MCPCallOutputIndices[outputIndex] { + doneType = schemas.ResponsesStreamResponseTypeMCPCallArgumentsDone + } else { + doneType = schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone + } + + response := &schemas.BifrostResponsesStreamResponse{ + Type: doneType, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Arguments: &accumulatedArgs, + } + if itemID != "" { + response.ItemID = &itemID + } + responses = append(responses, response) + // Clear the buffer and MCP tracking + delete(state.ToolArgumentBuffers, outputIndex) + delete(state.MCPCallOutputIndices, outputIndex) + } + + // Emit output_item.done for all content blocks (text, tool, etc.) + statusCompleted := "completed" + doneItemID := state.ItemIDs[outputIndex] + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, + } + if doneItemID != "" { + doneItem.ID = &doneItemID + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: doneItem, + }) + + return responses, nil, false + } + + case AnthropicStreamEventTypeMessageDelta: + if chunk.Delta.StopReason != nil { + state.StopReason = schemas.Ptr(ConvertAnthropicFinishReasonToBifrost(*chunk.Delta.StopReason)) + } + // Check if integration type in ctx is anthropic + if ctx.Value(schemas.BifrostContextKeyIntegrationType) == "anthropic" { + // Convert usage from Anthropic format to Bifrost format + var bifrostUsage *schemas.ResponsesResponseUsage + if chunk.Usage != nil { + bifrostUsage = &schemas.ResponsesResponseUsage{ + InputTokens: chunk.Usage.InputTokens, + OutputTokens: chunk.Usage.OutputTokens, + TotalTokens: chunk.Usage.InputTokens + chunk.Usage.OutputTokens, + } + // Handle cached tokens if present + if chunk.Usage.CacheReadInputTokens > 0 { + bifrostUsage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + CachedTokens: chunk.Usage.CacheReadInputTokens, + } + } + if chunk.Usage.CacheCreationInputTokens > 0 { + bifrostUsage.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ + CachedTokens: chunk.Usage.CacheCreationInputTokens, + } + } + } + + // Convert stop reason if present + var stopReason *string + if chunk.Delta != nil && chunk.Delta.StopReason != nil { + converted := ConvertAnthropicFinishReasonToBifrost(*chunk.Delta.StopReason) + stopReason = &converted + } + + // Create response object with usage and stop reason + response := &schemas.BifrostResponsesResponse{ + CreatedAt: state.CreatedAt, + } + if state.MessageID != nil { + response.ID = state.MessageID + } + if state.Model != nil { + response.Model = *state.Model + } + if stopReason != nil { + response.StopReason = stopReason + } + if bifrostUsage != nil { + response.Usage = bifrostUsage + } + + return []*schemas.BifrostResponsesStreamResponse{{ + Type: "message_delta", + SequenceNumber: sequenceNumber, + Response: response, + }}, nil, false + } + // Message-level updates (like stop reason, usage, etc.) + // Note: We don't emit output_item.done here because items are already closed + // by content_block_stop. This event is informational only. + return nil, nil, false + + case AnthropicStreamEventTypeMessageStop: + // Message stop - emit response.completed (OpenAI-style) + response := &schemas.BifrostResponsesResponse{ + CreatedAt: state.CreatedAt, + } + if state.MessageID != nil { + response.ID = state.MessageID + } + if state.Model != nil { + response.Model = *state.Model + } + if state.StopReason != nil { + response.StopReason = state.StopReason + } + + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + SequenceNumber: sequenceNumber, + Response: response, + }}, nil, true // Indicate stream is complete + + case AnthropicStreamEventTypePing: + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypePing, + SequenceNumber: sequenceNumber, + }}, nil, false + + case AnthropicStreamEventTypeError: + if chunk.Error != nil { + // Send error event + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Type: &chunk.Error.Type, + Message: chunk.Error.Message, + }, + } + + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeError, + SequenceNumber: sequenceNumber, + Message: &chunk.Error.Message, + }}, bifrostErr, false + } + } + + return nil, nil, false +} + +// ToAnthropicResponsesStreamResponse converts a Bifrost Responses stream response to Anthropic SSE string format +func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schemas.BifrostResponsesStreamResponse) []*AnthropicStreamEvent { + if bifrostResp == nil { + return nil + } + + streamResp := &AnthropicStreamEvent{} + + // Map ResponsesStreamResponse types to Anthropic stream events + switch bifrostResp.Type { + case schemas.ResponsesStreamResponseTypeCreated: + // Only convert response.created back to message_start (not response.in_progress to avoid duplicates) + streamResp.Type = AnthropicStreamEventTypeMessageStart + if bifrostResp.Response != nil { + streamMessage := &AnthropicMessageResponse{ + Type: "message", + Role: "assistant", + Content: []AnthropicContentBlock{}, // Always empty array in message_start + Usage: &AnthropicUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + CacheCreationInputTokens: 0, + CacheCreation: AnthropicUsageCacheCreation{ + Ephemeral5mInputTokens: 0, + Ephemeral1hInputTokens: 0, + }, + }, + } + if bifrostResp.Response.ID != nil { + streamMessage.ID = *bifrostResp.Response.ID + } + // Preserve model from Response if available, otherwise use ExtraFields + if bifrostResp.ExtraFields.ModelRequested != "" { + if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { + streamMessage.Model = bifrostResp.Response.Model + } else { + streamMessage.Model = bifrostResp.ExtraFields.ModelRequested + } + } + streamResp.Message = streamMessage + } + case schemas.ResponsesStreamResponseTypeInProgress: + // Skip converting response.in_progress back to avoid duplicate message_start events + // This is an OpenAI-style lifecycle event that doesn't map directly to Anthropic events + return nil + + case schemas.ResponsesStreamResponseTypeOutputItemAdded: + // Check if this is a computer tool call + if bifrostResp.Item != nil && + bifrostResp.Item.Type != nil && + *bifrostResp.Item.Type == schemas.ResponsesMessageTypeComputerCall { + + // Computer tool - emit content_block_start + streamResp.Type = AnthropicStreamEventTypeContentBlockStart + + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + + // Build the content_block as tool_use + // Note: Computer tool calls should not be converted to thinking blocks + contentBlock := &AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolUse, + ID: bifrostResp.Item.ID, // The tool use ID + Name: schemas.Ptr(string(AnthropicToolNameComputer)), // "computer" + } + + // Always start with empty input for streaming compatibility + contentBlock.Input = map[string]interface{}{} + + streamResp.ContentBlock = contentBlock + } else { + // Text or other content blocks - emit content_block_start + streamResp.Type = AnthropicStreamEventTypeContentBlockStart + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + + // Build content_block based on item type + if bifrostResp.Item != nil { + contentBlock := &AnthropicContentBlock{} + if bifrostResp.Item.Type != nil { + switch *bifrostResp.Item.Type { + case schemas.ResponsesMessageTypeMessage: + contentBlock.Type = AnthropicContentBlockTypeText + contentBlock.Text = schemas.Ptr("") + case schemas.ResponsesMessageTypeReasoning: + contentBlock.Type = AnthropicContentBlockTypeThinking + contentBlock.Thinking = schemas.Ptr("") + // Preserve signature if present + if bifrostResp.Item.ResponsesReasoning != nil && bifrostResp.Item.ResponsesReasoning.EncryptedContent != nil && *bifrostResp.Item.ResponsesReasoning.EncryptedContent != "" { + contentBlock.Data = bifrostResp.Item.ResponsesReasoning.EncryptedContent + // When signature is present but thinking content is empty, use redacted_thinking + if contentBlock.Thinking != nil && *contentBlock.Thinking == "" { + contentBlock.Type = AnthropicContentBlockTypeRedactedThinking + } + } + case schemas.ResponsesMessageTypeFunctionCall: + // Check if this item actually has reasoning content (misclassified) + // When thinking is enabled, reasoning content might be incorrectly classified as FunctionCall + if bifrostResp.Item.ResponsesReasoning != nil { + // This is actually reasoning content, not a function call + contentBlock.Type = AnthropicContentBlockTypeThinking + contentBlock.Thinking = schemas.Ptr("") + // Check if there's encrypted content for redacted_thinking + if bifrostResp.Item.ResponsesReasoning.EncryptedContent != nil && *bifrostResp.Item.ResponsesReasoning.EncryptedContent != "" { + contentBlock.Type = AnthropicContentBlockTypeRedactedThinking + contentBlock.Data = bifrostResp.Item.ResponsesReasoning.EncryptedContent + } + } else { + // Regular function call - check if ContentIndex is 0 and thinking might be enabled + // If ContentIndex is 0, we need to check if there's reasoning content in the response + contentIndex := 0 + if bifrostResp.ContentIndex != nil { + contentIndex = *bifrostResp.ContentIndex + } + isFirstBlock := contentIndex == 0 + + // Check if response has reasoning content (indicating thinking is enabled) + hasReasoningInResponse := false + if bifrostResp.Response != nil && bifrostResp.Response.Output != nil { + for _, msg := range bifrostResp.Response.Output { + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeReasoning { + hasReasoningInResponse = true + break + } + } + } + + // When thinking is enabled and this is the first block, use thinking/redacted_thinking + if isFirstBlock && hasReasoningInResponse { + contentBlock.Type = AnthropicContentBlockTypeThinking + contentBlock.Thinking = schemas.Ptr("") + } else { + contentBlock.Type = AnthropicContentBlockTypeToolUse + if bifrostResp.Item.ResponsesToolMessage != nil { + contentBlock.ID = bifrostResp.Item.ResponsesToolMessage.CallID + contentBlock.Name = bifrostResp.Item.ResponsesToolMessage.Name + // Always start with empty input for streaming compatibility + contentBlock.Input = map[string]interface{}{} + } + } + } + case schemas.ResponsesMessageTypeMCPCall: + contentBlock.Type = AnthropicContentBlockTypeMCPToolUse + if bifrostResp.Item.ResponsesToolMessage != nil { + contentBlock.ID = bifrostResp.Item.ID + contentBlock.Name = bifrostResp.Item.ResponsesToolMessage.Name + if bifrostResp.Item.ResponsesToolMessage.ResponsesMCPToolCall != nil { + contentBlock.ServerName = &bifrostResp.Item.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel + } + // Always start with empty input for streaming compatibility + contentBlock.Input = map[string]interface{}{} + } + } + } + if contentBlock.Type != "" { + streamResp.ContentBlock = contentBlock + } + } + } + + // Generate synthetic input_json_delta events for tool calls with arguments + var events []*AnthropicStreamEvent + events = append(events, streamResp) + + // Check if this is a tool call with arguments that need to be streamed + if bifrostResp.Item != nil && bifrostResp.Item.ResponsesToolMessage != nil { + var argumentsJSON string + var shouldGenerateDeltas bool + + switch *bifrostResp.Item.Type { + case schemas.ResponsesMessageTypeFunctionCall: + if bifrostResp.Item.ResponsesToolMessage.Arguments != nil && *bifrostResp.Item.ResponsesToolMessage.Arguments != "" { + argumentsJSON = *bifrostResp.Item.ResponsesToolMessage.Arguments + shouldGenerateDeltas = true + } + case schemas.ResponsesMessageTypeMCPCall: + if bifrostResp.Item.ResponsesToolMessage.Arguments != nil && *bifrostResp.Item.ResponsesToolMessage.Arguments != "" { + argumentsJSON = *bifrostResp.Item.ResponsesToolMessage.Arguments + shouldGenerateDeltas = true + } + case schemas.ResponsesMessageTypeComputerCall: + if bifrostResp.Item.ResponsesToolMessage.Action != nil && bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { + actionInput := convertResponsesToAnthropicComputerAction(bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) + if jsonBytes, err := json.Marshal(actionInput); err == nil { + argumentsJSON = string(jsonBytes) + shouldGenerateDeltas = true + } + } + } + + if shouldGenerateDeltas && argumentsJSON != "" { + // Generate synthetic input_json_delta events by chunking the JSON + deltaEvents := generateSyntheticInputJSONDeltas(argumentsJSON, bifrostResp.ContentIndex) + events = append(events, deltaEvents...) + } + } + + return events + case schemas.ResponsesStreamResponseTypeContentPartAdded: + return nil + + case schemas.ResponsesStreamResponseTypeOutputTextDelta: + streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } + if bifrostResp.Delta != nil { + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeText, + Text: bifrostResp.Delta, + } + } + + case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } + if bifrostResp.Arguments != nil { + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: bifrostResp.Arguments, + } + } else if bifrostResp.Delta != nil { + // Handle cases where Delta field is used instead of Arguments + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: bifrostResp.Delta, + } + } + + case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta: + streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } + + // Check if this is a signature delta or text delta + if bifrostResp.Signature != nil { + // This is a signature_delta + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeSignature, + Signature: bifrostResp.Signature, + } + } else if bifrostResp.Delta != nil { + // This is a thinking_delta + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeThinking, + Thinking: bifrostResp.Delta, + } + } + + case schemas.ResponsesStreamResponseTypeContentPartDone: + return nil + + case schemas.ResponsesStreamResponseTypeOutputItemDone: + if bifrostResp.Item != nil && + bifrostResp.Item.Type != nil && + *bifrostResp.Item.Type == schemas.ResponsesMessageTypeComputerCall { + + // Computer tool complete - emit content_block_delta with the action, then stop + // Note: We're sending the complete action JSON in one delta + streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + + // Convert the action to Anthropic format and marshal to JSON + if bifrostResp.Item.ResponsesToolMessage != nil && + bifrostResp.Item.ResponsesToolMessage.Action != nil && + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { + + actionInput := convertResponsesToAnthropicComputerAction( + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction, + ) + + // Marshal the action to JSON string + if jsonBytes, err := json.Marshal(actionInput); err == nil { + jsonStr := string(jsonBytes) + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: &jsonStr, + } + } + } + } else { + // For text blocks and other content blocks, emit content_block_stop + streamResp.Type = AnthropicStreamEventTypeContentBlockStop + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + } + case schemas.ResponsesStreamResponseTypePing: + streamResp.Type = AnthropicStreamEventTypePing + + case schemas.ResponsesStreamResponseTypeCompleted: + streamResp.Type = AnthropicStreamEventTypeMessageStop + anthropicContentDeltaEvent := &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeMessageDelta, + } + if bifrostResp.Response.Usage != nil { + anthropicContentDeltaEvent.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Response.Usage.InputTokens, + OutputTokens: bifrostResp.Response.Usage.OutputTokens, + } + if bifrostResp.Response.Usage.InputTokensDetails != nil && bifrostResp.Response.Usage.InputTokensDetails.CachedTokens > 0 { + anthropicContentDeltaEvent.Usage.CacheReadInputTokens = bifrostResp.Response.Usage.InputTokensDetails.CachedTokens + } + if bifrostResp.Response.Usage.OutputTokensDetails != nil && bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens > 0 { + anthropicContentDeltaEvent.Usage.CacheCreationInputTokens = bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens + } + } + if bifrostResp.Response.StopReason != nil { + anthropicContentDeltaEvent.Delta = &AnthropicStreamDelta{ + StopReason: schemas.Ptr(ConvertBifrostFinishReasonToAnthropic(*bifrostResp.Response.StopReason)), + } + } + return []*AnthropicStreamEvent{anthropicContentDeltaEvent, streamResp} + + case schemas.ResponsesStreamResponseTypeMCPCallArgumentsDelta: + // MCP call arguments delta - convert to content_block_delta with input_json + streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + if bifrostResp.Delta != nil { + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: bifrostResp.Delta, + } + } else if bifrostResp.Arguments != nil { + // Handle cases where Arguments field is used instead of Delta + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: bifrostResp.Arguments, + } + } + + case schemas.ResponsesStreamResponseTypeMCPCallCompleted: + // MCP call completed - emit content_block_stop + streamResp.Type = AnthropicStreamEventTypeContentBlockStop + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + + case schemas.ResponsesStreamResponseTypeMCPCallFailed: + // MCP call failed - emit error event + streamResp.Type = AnthropicStreamEventTypeError + errorMsg := "MCP call failed" + if bifrostResp.Message != nil { + errorMsg = *bifrostResp.Message + } + streamResp.Error = &AnthropicStreamError{ + Type: "error", + Message: errorMsg, + } + + case "message_delta": + // Check if integration type in ctx is anthropic + if ctx.Value(schemas.BifrostContextKeyIntegrationType) == "anthropic" { + streamResp.Type = AnthropicStreamEventTypeMessageDelta + + // Convert usage from Bifrost format to Anthropic format + if bifrostResp.Response != nil && bifrostResp.Response.Usage != nil { + streamResp.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Response.Usage.InputTokens, + OutputTokens: bifrostResp.Response.Usage.OutputTokens, + } + if bifrostResp.Response.Usage.InputTokensDetails != nil && bifrostResp.Response.Usage.InputTokensDetails.CachedTokens > 0 { + streamResp.Usage.CacheReadInputTokens = bifrostResp.Response.Usage.InputTokensDetails.CachedTokens + } + if bifrostResp.Response.Usage.OutputTokensDetails != nil && bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens > 0 { + streamResp.Usage.CacheCreationInputTokens = bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens + } + } + + // Convert stop reason from Bifrost format to Anthropic format + if bifrostResp.Response != nil && bifrostResp.Response.StopReason != nil { + streamResp.Delta = &AnthropicStreamDelta{ + StopReason: schemas.Ptr(ConvertBifrostFinishReasonToAnthropic(*bifrostResp.Response.StopReason)), + } + } else if bifrostResp.Delta != nil { + // Handle text delta if present + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeText, + Text: bifrostResp.Delta, + } + } + } + + case schemas.ResponsesStreamResponseTypeError: + streamResp.Type = AnthropicStreamEventTypeError + if bifrostResp.Message != nil { + streamResp.Error = &AnthropicStreamError{ + Type: "error", + Message: *bifrostResp.Message, + } + } + + default: + // Unknown event type, return empty + return nil + } + + return []*AnthropicStreamEvent{streamResp} +} + +// ToBifrostResponsesRequest converts an Anthropic message request to Bifrost format +func (request *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx context.Context) *schemas.BifrostResponsesRequest { + provider, model := schemas.ParseModelString(request.Model, schemas.Anthropic) + + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: provider, + Model: model, + Fallbacks: schemas.ParseFallbacks(request.Fallbacks), + } + + // Convert basic parameters + params := &schemas.ResponsesParameters{ + ExtraParams: make(map[string]interface{}), + } + + if request.MaxTokens > 0 { + params.MaxOutputTokens = &request.MaxTokens + } + if request.Temperature != nil { + params.Temperature = request.Temperature + } + if request.TopP != nil { + params.TopP = request.TopP + } + if request.Metadata != nil && request.Metadata.UserID != nil { + params.User = request.Metadata.UserID + } + if request.TopK != nil { + params.ExtraParams["top_k"] = *request.TopK + } + if request.StopSequences != nil { + params.ExtraParams["stop"] = request.StopSequences + } + if request.OutputFormat != nil { + params.Text = convertAnthropicOutputFormatToResponsesTextConfig(request.OutputFormat) + } + if request.Thinking != nil { + if request.Thinking.Type == "enabled" { + var summary *string + if summaryValue, ok := schemas.SafeExtractStringPointer(request.ExtraParams["reasoning_summary"]); ok { + summary = summaryValue + } + // check if user agent in ctx is claude-cli + if userAgent, ok := ctx.Value(schemas.BifrostContextKeyUserAgent).(string); ok { + if strings.Contains(userAgent, "claude-cli") { + summary = schemas.Ptr("detailed") + } + } + params.Reasoning = &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr(providerUtils.GetReasoningEffortFromBudgetTokens(*request.Thinking.BudgetTokens, MinimumReasoningMaxTokens, AnthropicDefaultMaxTokens)), + MaxTokens: request.Thinking.BudgetTokens, + Summary: summary, + } + } else { + params.Reasoning = &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("none"), + } + } + } + if include, ok := schemas.SafeExtractStringSlice(request.ExtraParams["include"]); ok { + params.Include = include + } + + // Add trucation parameter if computer tool is being used + if provider == schemas.OpenAI && request.Tools != nil { + for _, tool := range request.Tools { + if tool.Type != nil && *tool.Type == AnthropicToolTypeComputer20250124 { + params.Truncation = schemas.Ptr("auto") + break + } + } + } + + bifrostReq.Params = params + + // Convert messages directly to ChatMessage format + var bifrostMessages []schemas.ResponsesMessage + + // Handle system message - convert Anthropic system field to first message with role "system" + if request.System != nil { + var systemText string + if request.System.ContentStr != nil { + systemText = *request.System.ContentStr + } else if request.System.ContentBlocks != nil { + // Combine text blocks from system content + var textParts []string + for _, block := range request.System.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + systemText = strings.Join(textParts, "\n") + } + + if systemText != "" { + systemMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &systemText, + }, + } + bifrostMessages = append(bifrostMessages, systemMsg) + } + } + + // Convert regular messages using the new conversion method + convertedMessages := ConvertAnthropicMessagesToBifrostMessages(request.Messages, nil, false, provider == schemas.Bedrock) + bifrostMessages = append(bifrostMessages, convertedMessages...) + + // Convert tools if present + if request.Tools != nil { + var bifrostTools []schemas.ResponsesTool + for _, tool := range request.Tools { + bifrostTool := convertAnthropicToolToBifrost(&tool) + if bifrostTool != nil { + bifrostTools = append(bifrostTools, *bifrostTool) + } + } + if len(bifrostTools) > 0 { + bifrostReq.Params.Tools = bifrostTools + } + } + + if request.MCPServers != nil { + var bifrostMCPTools []schemas.ResponsesTool + for _, mcpServer := range request.MCPServers { + bifrostMCPTool := convertAnthropicMCPServerToBifrostTool(&mcpServer) + if bifrostMCPTool != nil { + bifrostMCPTools = append(bifrostMCPTools, *bifrostMCPTool) + } + } + if len(bifrostMCPTools) > 0 { + bifrostReq.Params.Tools = append(bifrostReq.Params.Tools, bifrostMCPTools...) + } + } + + // Convert tool choice if present + if request.ToolChoice != nil { + bifrostToolChoice := convertAnthropicToolChoiceToBifrost(request.ToolChoice) + if bifrostToolChoice != nil { + bifrostReq.Params.ToolChoice = bifrostToolChoice + } + } + + // Set the converted messages + if len(bifrostMessages) > 0 { + bifrostReq.Input = bifrostMessages + } + + return bifrostReq +} + +// ToAnthropicResponsesRequest converts a BifrostRequest with Responses structure back to AnthropicMessageRequest +func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*AnthropicMessageRequest, error) { + if bifrostReq == nil { + return nil, fmt.Errorf("bifrost request is nil") + } + + anthropicReq := &AnthropicMessageRequest{ + Model: bifrostReq.Model, + MaxTokens: AnthropicDefaultMaxTokens, + } + + // Convert basic parameters + if bifrostReq.Params != nil { + if bifrostReq.Params.MaxOutputTokens != nil { + anthropicReq.MaxTokens = *bifrostReq.Params.MaxOutputTokens } if bifrostReq.Params.Temperature != nil { anthropicReq.Temperature = bifrostReq.Params.Temperature @@ -323,23 +1446,29 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (* anthropicReq.OutputFormat = convertResponsesTextConfigToAnthropicOutputFormat(bifrostReq.Params.Text) } if bifrostReq.Params.Reasoning != nil { - if bifrostReq.Params.Reasoning.Effort != nil { - if *bifrostReq.Params.Reasoning.Effort != "none" { - if bifrostReq.Params.Reasoning.MaxTokens != nil { - if *bifrostReq.Params.Reasoning.MaxTokens < MinimumReasoningMaxTokens { - return nil, fmt.Errorf("reasoning.max_tokens must be greater than or equal to %d", MinimumReasoningMaxTokens) - } else { - anthropicReq.Thinking = &AnthropicThinking{ - Type: "enabled", - BudgetTokens: bifrostReq.Params.Reasoning.MaxTokens, - } + if bifrostReq.Params.Reasoning.MaxTokens != nil { + if *bifrostReq.Params.Reasoning.MaxTokens < MinimumReasoningMaxTokens { + return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", MinimumReasoningMaxTokens) + } + anthropicReq.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: bifrostReq.Params.Reasoning.MaxTokens, + } + } else { + if bifrostReq.Params.Reasoning.Effort != nil { + if *bifrostReq.Params.Reasoning.Effort != "none" { + budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, anthropicReq.MaxTokens) + if err != nil { + return nil, err + } + anthropicReq.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: schemas.Ptr(budgetTokens), } } else { - return nil, fmt.Errorf("reasoning.max_tokens is required for reasoning") - } - } else { - anthropicReq.Thinking = &AnthropicThinking{ - Type: "disabled", + anthropicReq.Thinking = &AnthropicThinking{ + Type: "disabled", + } } } } @@ -390,7 +1519,7 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (* } if bifrostReq.Input != nil { - anthropicMessages, systemContent := convertResponsesMessagesToAnthropicMessages(bifrostReq.Input) + anthropicMessages, systemContent := ConvertBifrostMessagesToAnthropicMessages(bifrostReq.Input) // Set system message if present if systemContent != nil { @@ -439,10 +1568,19 @@ func (response *AnthropicMessageResponse) ToBifrostResponsesResponse() *schemas. } } - // Convert content to Responses output messages - outputMessages := convertAnthropicContentBlocksToResponsesMessages(response.Content) - if len(outputMessages) > 0 { - bifrostResp.Output = outputMessages + // Convert content to Responses output messages using the new conversion method + if len(response.Content) > 0 { + // Create a temporary message to use the conversion method + tempMsg := AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + Content: AnthropicContent{ + ContentBlocks: response.Content, + }, + } + outputMessages := ConvertAnthropicMessagesToBifrostMessages([]AnthropicMessage{tempMsg}, nil, true, false) + if len(outputMessages) > 0 { + bifrostResp.Output = outputMessages + } } bifrostResp.Model = response.Model @@ -475,10 +1613,21 @@ func ToAnthropicResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) } } - // Convert output messages to Anthropic content blocks + // Convert output messages to Anthropic content blocks using the new conversion method var contentBlocks []AnthropicContentBlock if bifrostResp.Output != nil { - contentBlocks = convertBifrostMessagesToAnthropicContent(bifrostResp.Output) + anthropicMessages, _ := ConvertBifrostMessagesToAnthropicMessages(bifrostResp.Output) + // Extract content blocks from the converted messages + for _, msg := range anthropicMessages { + if msg.Content.ContentBlocks != nil { + contentBlocks = append(contentBlocks, msg.Content.ContentBlocks...) + } else if msg.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, AnthropicContentBlock{ + Type: AnthropicContentBlockTypeText, + Text: msg.Content.ContentStr, + }) + } + } } if len(contentBlocks) > 0 { @@ -503,1115 +1652,1352 @@ func ToAnthropicResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) return anthropicResp } -// ToBifrostResponsesStream converts an Anthropic stream event to a Bifrost Responses Stream response -// It maintains state via the state for handling multi-chunk conversions like computer tools -// Returns a slice of responses to support cases where a single event produces multiple responses -func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, sequenceNumber int, state *AnthropicResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError, bool) { - switch chunk.Type { - case AnthropicStreamEventTypeMessageStart: - // Message start - emit response.created and response.in_progress (OpenAI-style lifecycle) - if chunk.Message != nil { - state.MessageID = &chunk.Message.ID - state.Model = &chunk.Message.Model - // Use the state's CreatedAt for consistency - if state.CreatedAt == 0 { - state.CreatedAt = int(time.Now().Unix()) - } - - var responses []*schemas.BifrostResponsesStreamResponse - - // Emit response.created - if !state.HasEmittedCreated { - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, - } - if state.Model != nil { - response.Model = *state.Model - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeCreated, - SequenceNumber: sequenceNumber, - Response: response, - }) - state.HasEmittedCreated = true - } +// ConvertAnthropicMessagesToBifrostMessages converts an array of Anthropic messages to Bifrost ResponsesMessage format +func ConvertAnthropicMessagesToBifrostMessages(anthropicMessages []AnthropicMessage, systemContent *AnthropicContent, isOutputMessage bool, keepToolsGrouped bool) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage - // Emit response.in_progress - if !state.HasEmittedInProgress { - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, // Use same timestamp - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeInProgress, - SequenceNumber: sequenceNumber + len(responses), - Response: response, - }) - state.HasEmittedInProgress = true - } + // Handle system message first if present + if systemContent != nil { + systemMessages := convertAnthropicSystemToBifrostMessages(systemContent) + bifrostMessages = append(bifrostMessages, systemMessages...) + } - if len(responses) > 0 { - return responses, nil, false - } + // Convert regular messages + for _, msg := range anthropicMessages { + var convertedMessages []schemas.ResponsesMessage + if keepToolsGrouped { + convertedMessages = convertSingleAnthropicMessageToBifrostMessagesGrouped(&msg, isOutputMessage) + } else { + convertedMessages = convertSingleAnthropicMessageToBifrostMessages(&msg, isOutputMessage) } + bifrostMessages = append(bifrostMessages, convertedMessages...) + } - case AnthropicStreamEventTypeContentBlockStart: - // Content block start - emit output_item.added (OpenAI-style) - if chunk.ContentBlock != nil && chunk.Index != nil { - outputIndex := state.getOrCreateOutputIndex(chunk.Index) - - if chunk.ContentBlock.Type == AnthropicContentBlockTypeToolUse && - chunk.ContentBlock.Name != nil && - *chunk.ContentBlock.Name == string(AnthropicToolNameComputer) && - chunk.ContentBlock.ID != nil { - - // Start accumulating computer tool - state.ComputerToolID = chunk.ContentBlock.ID - state.ChunkIndex = chunk.Index - state.AccumulatedJSON = "" - - // Emit output_item.added for computer_call - item := &schemas.ResponsesMessage{ - ID: chunk.ContentBlock.ID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeComputerCall), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: chunk.ContentBlock.ID, - }, - } - - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }}, nil, false - } - - switch chunk.ContentBlock.Type { - case AnthropicContentBlockTypeText: - // Text block - emit output_item.added with type "message" - messageType := schemas.ResponsesMessageTypeMessage - role := schemas.ResponsesInputMessageRoleAssistant - - // Generate stable ID for text item - var itemID string - if state.MessageID == nil { - itemID = fmt.Sprintf("item_%d", outputIndex) - } else { - itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) - } - state.ItemIDs[outputIndex] = itemID - - item := &schemas.ResponsesMessage{ - ID: &itemID, - Type: &messageType, - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // Empty blocks slice for mutation support - }, - } - - // Track that this content index is a text block - if chunk.Index != nil { - state.TextContentIndices[*chunk.Index] = true - } - - var responses []*schemas.BifrostResponsesStreamResponse - - // Emit output_item.added - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }) - - // Emit content_part.added with empty output_text part - emptyText := "" - part := &schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &emptyText, - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - ItemID: &itemID, - Part: part, - }) - - return responses, nil, false - - case AnthropicContentBlockTypeToolUse: - // Function call starting - emit output_item.added with type "function_call" and status "in_progress" - statusInProgress := "in_progress" - itemID := "" - if chunk.ContentBlock.ID != nil { - itemID = *chunk.ContentBlock.ID - state.ItemIDs[outputIndex] = itemID - } - item := &schemas.ResponsesMessage{ - ID: chunk.ContentBlock.ID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: &statusInProgress, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: chunk.ContentBlock.ID, - Name: chunk.ContentBlock.Name, - Arguments: schemas.Ptr(""), // Arguments will be filled by deltas - }, - } - - // Initialize argument buffer for this tool call - state.ToolArgumentBuffers[outputIndex] = "" - - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }}, nil, false - - case AnthropicContentBlockTypeMCPToolUse: - // MCP tool call starting - emit output_item.added - itemID := "" - if chunk.ContentBlock.ID != nil { - itemID = *chunk.ContentBlock.ID - state.ItemIDs[outputIndex] = itemID - } - item := &schemas.ResponsesMessage{ - ID: chunk.ContentBlock.ID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - Name: chunk.ContentBlock.Name, - Arguments: schemas.Ptr(""), // Arguments will be filled by deltas - }, - } + return bifrostMessages +} - // Set server name if present - if chunk.ContentBlock.ServerName != nil { - item.ResponsesToolMessage.ResponsesMCPToolCall = &schemas.ResponsesMCPToolCall{ - ServerLabel: *chunk.ContentBlock.ServerName, - } - } +// ConvertBifrostMessagesToAnthropicMessages converts an array of Bifrost ResponsesMessage to Anthropic message format +// This is the main conversion method from Bifrost to Anthropic - handles all message types and returns messages + system content +func ConvertBifrostMessagesToAnthropicMessages(bifrostMessages []schemas.ResponsesMessage) ([]AnthropicMessage, *AnthropicContent) { + var anthropicMessages []AnthropicMessage + var systemContent *AnthropicContent + var pendingToolCalls []AnthropicContentBlock + var pendingToolResultBlocks []AnthropicContentBlock + var pendingReasoningContentBlocks []AnthropicContentBlock + var currentAssistantMessage *AnthropicMessage - // Initialize argument buffer for this MCP call and mark as MCP - state.ToolArgumentBuffers[outputIndex] = "" - state.MCPCallOutputIndices[outputIndex] = true + // Track tool call IDs for each assistant turn to properly match tool results + // Each assistant turn that contains tool_use blocks should have its tool results + // grouped in a corresponding user message + type toolCallGroup struct { + toolCallIDs map[string]bool // Set of tool call IDs in this group + flushed bool // Whether the tool results for this group have been flushed + } + var toolCallGroups []toolCallGroup + var currentToolCallIDs map[string]bool // IDs of tool calls in the current pending batch + + // Helper to flush pending tool result blocks into user messages + // This now matches tool results to their corresponding tool call groups + flushPendingToolResults := func() { + if len(pendingToolResultBlocks) == 0 { + return + } + + // If there are no tool call groups, just flush all results together + if len(toolCallGroups) == 0 { + anthropicMessages = append(anthropicMessages, AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentBlocks: pendingToolResultBlocks, + }, + }) + pendingToolResultBlocks = nil + return + } - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }}, nil, false + // Group tool results by their corresponding tool call group + // Each group should be flushed as a separate user message + for i := range toolCallGroups { + if toolCallGroups[i].flushed { + continue + } - case AnthropicContentBlockTypeThinking: - // Thinking/reasoning block - emit output_item.added with type "reasoning" - messageType := schemas.ResponsesMessageTypeReasoning - role := schemas.ResponsesInputMessageRoleAssistant + var groupResults []AnthropicContentBlock + var remainingResults []AnthropicContentBlock - // Generate stable ID for reasoning item - var itemID string - if state.MessageID == nil { - itemID = fmt.Sprintf("reasoning_%d", outputIndex) + for _, block := range pendingToolResultBlocks { + if block.ToolUseID != nil && toolCallGroups[i].toolCallIDs[*block.ToolUseID] { + groupResults = append(groupResults, block) } else { - itemID = fmt.Sprintf("msg_%s_reasoning_%d", *state.MessageID, outputIndex) + remainingResults = append(remainingResults, block) } - state.ItemIDs[outputIndex] = itemID + } - // Initialize reasoning structure - item := &schemas.ResponsesMessage{ - ID: &itemID, - Type: &messageType, - Role: &role, - ResponsesReasoning: &schemas.ResponsesReasoning{ - Summary: []schemas.ResponsesReasoningContent{}, + if len(groupResults) > 0 { + anthropicMessages = append(anthropicMessages, AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentBlocks: groupResults, }, - } + }) + toolCallGroups[i].flushed = true + pendingToolResultBlocks = remainingResults + } + } - // Preserve signature if present - if chunk.ContentBlock.Signature != nil { - item.ResponsesReasoning.EncryptedContent = chunk.ContentBlock.Signature - } + // Flush any remaining tool results that didn't match any group + if len(pendingToolResultBlocks) > 0 { + anthropicMessages = append(anthropicMessages, AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentBlocks: pendingToolResultBlocks, + }, + }) + pendingToolResultBlocks = nil + } + } - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }}, nil, false + // Helper to flush pending tool calls with tool call ID tracking + flushPendingToolCallsWithTracking := func() { + if len(pendingToolCalls) > 0 && currentAssistantMessage != nil { + // Copy the slice to avoid aliasing issues + copied := make([]AnthropicContentBlock, len(pendingToolCalls)) + copy(copied, pendingToolCalls) + currentAssistantMessage.Content = AnthropicContent{ + ContentBlocks: copied, } - } + anthropicMessages = append(anthropicMessages, *currentAssistantMessage) - case AnthropicStreamEventTypeContentBlockDelta: - if chunk.Index != nil && chunk.Delta != nil { - outputIndex := state.getOrCreateOutputIndex(chunk.Index) + // Record this tool call group for matching with tool results + if len(currentToolCallIDs) > 0 { + toolCallGroups = append(toolCallGroups, toolCallGroup{ + toolCallIDs: currentToolCallIDs, + flushed: false, + }) + currentToolCallIDs = nil + } - // Handle different delta types - switch chunk.Delta.Type { - case AnthropicStreamDeltaTypeText: - if chunk.Delta.Text != nil && *chunk.Delta.Text != "" { - // Text content delta - emit output_text.delta with item ID - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Delta: chunk.Delta.Text, - } - if itemID != "" { - response.ItemID = &itemID - } - return []*schemas.BifrostResponsesStreamResponse{response}, nil, false - } + pendingToolCalls = nil + currentAssistantMessage = nil + } + } - case AnthropicStreamDeltaTypeInputJSON: - // Function call arguments delta - if chunk.Delta.PartialJSON != nil && *chunk.Delta.PartialJSON != "" { - // Check if we're accumulating a computer tool - if state.ComputerToolID != nil && - state.ChunkIndex != nil && - *state.ChunkIndex == *chunk.Index { - // Accumulate the JSON and don't emit anything - state.AccumulatedJSON += *chunk.Delta.PartialJSON - return nil, nil, false - } + for _, msg := range bifrostMessages { + // Handle nil Type as regular message + msgType := schemas.ResponsesMessageTypeMessage + if msg.Type != nil { + msgType = *msg.Type + } - // Accumulate tool arguments in buffer - if _, exists := state.ToolArgumentBuffers[outputIndex]; !exists { - state.ToolArgumentBuffers[outputIndex] = "" - } - state.ToolArgumentBuffers[outputIndex] += *chunk.Delta.PartialJSON + switch msgType { + case schemas.ResponsesMessageTypeMessage: + // Flush any pending tool results before processing other message types + flushPendingToolResults() - // Emit appropriate delta type based on whether this is an MCP call - var deltaType schemas.ResponsesStreamResponseType - if state.MCPCallOutputIndices[outputIndex] { - deltaType = schemas.ResponsesStreamResponseTypeMCPCallArgumentsDelta - } else { - deltaType = schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta - } + // Flush any pending tool calls first (with tracking for tool call groups) + flushPendingToolCallsWithTracking() - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: deltaType, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Delta: chunk.Delta.PartialJSON, - } - if itemID != "" { - response.ItemID = &itemID - } - return []*schemas.BifrostResponsesStreamResponse{response}, nil, false - } + // Handle system messages separately + if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { + systemContent = convertBifrostMessageToAnthropicSystemContent(&msg) + continue + } - case AnthropicStreamDeltaTypeThinking: - // Reasoning/thinking content delta - if chunk.Delta.Thinking != nil && *chunk.Delta.Thinking != "" { - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Delta: chunk.Delta.Thinking, - }}, nil, false + // If there are pending reasoning blocks and this is a user message, + // flush them into a separate assistant message first + // (thinking blocks can only appear in assistant messages in Anthropic) + if len(pendingReasoningContentBlocks) > 0 && (msg.Role == nil || *msg.Role == schemas.ResponsesInputMessageRoleUser) { + // Copy the pending reasoning content blocks + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + assistantReasoningMsg := AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + Content: AnthropicContent{ + ContentBlocks: copied, + }, } + anthropicMessages = append(anthropicMessages, assistantReasoningMsg) + pendingReasoningContentBlocks = nil + } - case AnthropicStreamDeltaTypeSignature: - // Handle signature verification for thinking content - // Store the signature in state for the reasoning item - if chunk.Delta.Signature != nil && *chunk.Delta.Signature != "" { - state.ReasoningSignatures[outputIndex] = *chunk.Delta.Signature - // Emit signature_delta event to pass through - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, // Reuse this type for signature - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Delta: chunk.Delta.Signature, - }}, nil, false - } - return nil, nil, false + // Regular user/assistant message + anthropicMsg := convertBifrostMessageToAnthropicMessage(&msg, &pendingReasoningContentBlocks) + if anthropicMsg != nil { + anthropicMessages = append(anthropicMessages, *anthropicMsg) } - } - case AnthropicStreamEventTypeContentBlockStop: - // Content block is complete - emit output_item.done (OpenAI-style) - if chunk.Index != nil { - outputIndex := state.getOrCreateOutputIndex(chunk.Index) + case schemas.ResponsesMessageTypeReasoning: + // Flush any pending tool results before processing reasoning + flushPendingToolResults() - // Check if this is the end of a computer tool accumulation - if state.ComputerToolID != nil && - state.ChunkIndex != nil && - *state.ChunkIndex == *chunk.Index { + // Handle reasoning as thinking content + reasoningBlocks := convertBifrostReasoningToAnthropicThinking(&msg) + pendingReasoningContentBlocks = append(pendingReasoningContentBlocks, reasoningBlocks...) - // Parse accumulated JSON and convert to OpenAI format - var inputMap map[string]interface{} - var action *schemas.ResponsesComputerToolCallAction + case schemas.ResponsesMessageTypeFunctionCall: + // Flush any pending tool results before processing function calls + flushPendingToolResults() - if state.AccumulatedJSON != "" { - if err := json.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { - action = convertAnthropicToResponsesComputerAction(inputMap) - } + // When thinking blocks exist, they MUST come first before tool_use blocks + // If we have pending reasoning blocks, we need to prepend them to the assistant message + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, } + } - // Create computer_call item with action - statusCompleted := "completed" - item := &schemas.ResponsesMessage{ - ID: state.ComputerToolID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeComputerCall), - Status: &statusCompleted, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: state.ComputerToolID, - ResponsesComputerToolCall: &schemas.ResponsesComputerToolCall{ - PendingSafetyChecks: []schemas.ResponsesComputerToolCallPendingSafetyCheck{}, - }, - }, + // Prepend any pending reasoning blocks to ensure they come BEFORE tool_use blocks + // This is required by Anthropic/Bedrock API: if an assistant message contains thinking blocks, + // the first block must be thinking or redacted_thinking, NOT tool_use + if len(pendingReasoningContentBlocks) > 0 { + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + pendingToolCalls = append(copied, pendingToolCalls...) + pendingReasoningContentBlocks = nil + } + + toolUseBlock := convertBifrostFunctionCallToAnthropicToolUse(&msg) + if toolUseBlock != nil { + // If there was a previous assistant message (text only) that was just added, + // and we have no pending tool calls yet, we should merge the tool call into it. + // This handles the case where an assistant text message precedes tool calls. + if len(pendingToolCalls) == 0 && len(anthropicMessages) > 0 { + lastMsgIdx := len(anthropicMessages) - 1 + lastMsg := &anthropicMessages[lastMsgIdx] + + // Check if the last message is an assistant message that could have text + if lastMsg.Role == AnthropicMessageRoleAssistant { + hasToolUse := false + for _, block := range lastMsg.Content.ContentBlocks { + if block.Type == AnthropicContentBlockTypeToolUse { + hasToolUse = true + break + } + } + // If the last assistant message has no tool_use blocks, merge the tool call into it + if !hasToolUse { + // Copy existing content blocks and append the tool_use + existingBlocks := lastMsg.Content.ContentBlocks + existingBlocks = append(existingBlocks, *toolUseBlock) + lastMsg.Content = AnthropicContent{ + ContentBlocks: existingBlocks, + } + // Track the tool call ID + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) + } + if toolUseBlock.ID != nil { + currentToolCallIDs[*toolUseBlock.ID] = true + } + // Use this message as the current one for subsequent tool calls + pendingToolCalls = lastMsg.Content.ContentBlocks + anthropicMessages = anthropicMessages[:lastMsgIdx] // Remove it, will be re-added on flush + currentAssistantMessage = lastMsg + continue + } + } } - // Add action if we successfully parsed it - if action != nil { - item.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesComputerToolCallAction: action, - } + pendingToolCalls = append(pendingToolCalls, *toolUseBlock) + + // Track the tool call ID for matching with tool results + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) } + if toolUseBlock.ID != nil { + currentToolCallIDs[*toolUseBlock.ID] = true + } + } - // Clear computer tool state - state.ComputerToolID = nil - state.ChunkIndex = nil - state.AccumulatedJSON = "" + case schemas.ResponsesMessageTypeFunctionCallOutput: + // Flush any pending tool calls first before processing tool results (with tracking) + flushPendingToolCallsWithTracking() - // Return output_item.done - return []*schemas.BifrostResponsesStreamResponse{ - { - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }, - }, nil, false + // Accumulate tool result blocks - they will be merged into a single user message + // This is required because Anthropic/Bedrock expect all tool results for parallel + // tool calls to be in the same user message, in the same order as the tool calls + toolResultBlock := convertBifrostFunctionCallOutputToAnthropicToolResultBlock(&msg) + if toolResultBlock != nil { + pendingToolResultBlocks = append(pendingToolResultBlocks, *toolResultBlock) + } + + case schemas.ResponsesMessageTypeItemReference: + // Flush any pending tool results before processing item reference + flushPendingToolResults() + + // Handle item reference as regular text message + referenceMsg := convertBifrostItemReferenceToAnthropicMessage(&msg) + if referenceMsg != nil { + anthropicMessages = append(anthropicMessages, *referenceMsg) + } + + case schemas.ResponsesMessageTypeComputerCall: + // Flush any pending tool results before processing computer calls + flushPendingToolResults() + + // Start accumulating computer tool calls for assistant message + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } } - // Check if this is a text block - emit output_text.done and content_part.done - var responses []*schemas.BifrostResponsesStreamResponse - itemID := state.ItemIDs[outputIndex] + // Prepend any pending reasoning blocks to ensure they come BEFORE tool_use blocks + if len(pendingReasoningContentBlocks) > 0 { + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + pendingToolCalls = append(copied, pendingToolCalls...) + pendingReasoningContentBlocks = nil + } - // Check if this content index is a text block - if chunk.Index != nil { - if state.TextContentIndices[*chunk.Index] { - // Emit output_text.done (without accumulated text, just the event) - emptyText := "" - textDoneResponse := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Text: &emptyText, + computerToolUseBlock := convertBifrostComputerCallToAnthropicToolUse(&msg) + if computerToolUseBlock != nil { + pendingToolCalls = append(pendingToolCalls, *computerToolUseBlock) + + // Track the tool call ID for matching with tool results + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) + } + if computerToolUseBlock.ID != nil { + currentToolCallIDs[*computerToolUseBlock.ID] = true + } + } + + case schemas.ResponsesMessageTypeMCPCall: + // Check if this is a tool use (from assistant) or tool result (from user) + if msg.ResponsesToolMessage != nil { + if msg.ResponsesToolMessage.Name != nil { + // Flush any pending tool results before processing MCP calls + flushPendingToolResults() + + // This is a tool use call (assistant calling a tool) + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } } - if itemID != "" { - textDoneResponse.ItemID = &itemID + + // Prepend any pending reasoning blocks to ensure they come BEFORE tool_use blocks + if len(pendingReasoningContentBlocks) > 0 { + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + pendingToolCalls = append(copied, pendingToolCalls...) + pendingReasoningContentBlocks = nil } - responses = append(responses, textDoneResponse) - // Emit content_part.done - partDoneResponse := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, + mcpToolUseBlock := convertBifrostMCPCallToAnthropicToolUse(&msg) + if mcpToolUseBlock != nil { + pendingToolCalls = append(pendingToolCalls, *mcpToolUseBlock) + + // Track the tool call ID for matching with tool results + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) + } + if mcpToolUseBlock.ID != nil { + currentToolCallIDs[*mcpToolUseBlock.ID] = true + } } - if itemID != "" { - partDoneResponse.ItemID = &itemID + } else if msg.ResponsesToolMessage.CallID != nil { + // This is a tool result (user providing result of tool execution) + // Accumulate with other tool results + mcpToolResultBlock := convertBifrostMCPCallOutputToAnthropicToolResultBlock(&msg) + if mcpToolResultBlock != nil { + pendingToolResultBlocks = append(pendingToolResultBlocks, *mcpToolResultBlock) } - responses = append(responses, partDoneResponse) - - // Clear the text content index tracking - delete(state.TextContentIndices, *chunk.Index) } } - // Check if this is a tool call (function_call or MCP call) - // If we have accumulated arguments, emit appropriate arguments.done first - if accumulatedArgs, hasArgs := state.ToolArgumentBuffers[outputIndex]; hasArgs && accumulatedArgs != "" { - // Emit appropriate arguments.done based on whether this is an MCP call - var doneType schemas.ResponsesStreamResponseType - if state.MCPCallOutputIndices[outputIndex] { - doneType = schemas.ResponsesStreamResponseTypeMCPCallArgumentsDone - } else { - doneType = schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone - } + case schemas.ResponsesMessageTypeMCPApprovalRequest: + // Flush any pending tool results before processing MCP approval requests + flushPendingToolResults() - response := &schemas.BifrostResponsesStreamResponse{ - Type: doneType, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Arguments: &accumulatedArgs, - } - if itemID != "" { - response.ItemID = &itemID + // MCP approval request is OpenAI-specific for human-in-the-loop workflows + // Convert to Anthropic's mcp_tool_use format (same as regular MCP calls) + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, } - responses = append(responses, response) - // Clear the buffer and MCP tracking - delete(state.ToolArgumentBuffers, outputIndex) - delete(state.MCPCallOutputIndices, outputIndex) } - // Check if this is a reasoning item and we have a signature - // If so, include the signature in the reasoning item - if signature, hasSignature := state.ReasoningSignatures[outputIndex]; hasSignature && signature != "" { - itemID := state.ItemIDs[outputIndex] - // Find if we have a reasoning item in responses or create one - var reasoningItem *schemas.ResponsesMessage - for _, resp := range responses { - if resp.Item != nil && resp.Item.Type != nil && *resp.Item.Type == schemas.ResponsesMessageTypeReasoning { - reasoningItem = resp.Item - break - } - } - if reasoningItem == nil { - reasoningItem = &schemas.ResponsesMessage{ - ID: &itemID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), - } + // Prepend any pending reasoning blocks to ensure they come BEFORE tool_use blocks + if len(pendingReasoningContentBlocks) > 0 { + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + pendingToolCalls = append(copied, pendingToolCalls...) + pendingReasoningContentBlocks = nil + } + + mcpApprovalBlock := convertBifrostMCPApprovalToAnthropicToolUse(&msg) + if mcpApprovalBlock != nil { + pendingToolCalls = append(pendingToolCalls, *mcpApprovalBlock) + + // Track the tool call ID for matching with tool results + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) } - if reasoningItem.ResponsesReasoning == nil { - reasoningItem.ResponsesReasoning = &schemas.ResponsesReasoning{} + if mcpApprovalBlock.ID != nil { + currentToolCallIDs[*mcpApprovalBlock.ID] = true } - reasoningItem.ResponsesReasoning.EncryptedContent = &signature } - // Emit output_item.done for all content blocks (text, tool, etc.) - statusCompleted := "completed" - doneItemID := state.ItemIDs[outputIndex] - doneItem := &schemas.ResponsesMessage{ - Status: &statusCompleted, + // Handle other tool call types that are not natively supported by Anthropic + case schemas.ResponsesMessageTypeFileSearchCall, + schemas.ResponsesMessageTypeCodeInterpreterCall, + schemas.ResponsesMessageTypeWebSearchCall, + schemas.ResponsesMessageTypeLocalShellCall, + schemas.ResponsesMessageTypeCustomToolCall, + schemas.ResponsesMessageTypeImageGenerationCall: + // Flush any pending tool results before processing unsupported tool calls + flushPendingToolResults() + + // Convert unsupported tool calls to regular text messages + unsupportedToolMsg := convertBifrostUnsupportedToolCallToAnthropicMessage(&msg, msgType) + if unsupportedToolMsg != nil { + anthropicMessages = append(anthropicMessages, *unsupportedToolMsg) } - if doneItemID != "" { - doneItem.ID = &doneItemID + + case schemas.ResponsesMessageTypeComputerCallOutput: + // Flush any pending tool calls first before processing tool results (with tracking) + flushPendingToolCallsWithTracking() + + // Accumulate computer call output with other tool results + computerResultBlock := convertBifrostComputerCallOutputToAnthropicToolResultBlock(&msg) + if computerResultBlock != nil { + pendingToolResultBlocks = append(pendingToolResultBlocks, *computerResultBlock) } - // Include signature if this is a reasoning item - if signature, hasSignature := state.ReasoningSignatures[outputIndex]; hasSignature && signature != "" { - if doneItem.ResponsesReasoning == nil { - doneItem.ResponsesReasoning = &schemas.ResponsesReasoning{} - } - doneItem.ResponsesReasoning.EncryptedContent = &signature + + case schemas.ResponsesMessageTypeLocalShellCallOutput, + schemas.ResponsesMessageTypeCustomToolCallOutput: + // Handle tool outputs as user messages + toolOutputMsg := convertBifrostToolOutputToAnthropicMessage(&msg) + if toolOutputMsg != nil { + anthropicMessages = append(anthropicMessages, *toolOutputMsg) } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: doneItem, - }) - return responses, nil, false + default: + // Skip unknown message types or log them for debugging + continue } + } - case AnthropicStreamEventTypeMessageDelta: - isAnthropicPassthrough, ok := ctx.Value(schemas.BifrostContextKey("is_anthropic_passthrough")).(bool) - if ok && isAnthropicPassthrough { - // Message-level updates (like stop reason, usage, etc.) - // For Anthropic passthrough mode, we should forward these events as they contain - // important information like stop_reason and final usage counts - - // Create a message_delta event that will be passed through in raw mode - // Since there's no specific BifrostResponsesStreamResponse type for message deltas, - // we'll use a custom approach that allows the integration layer to pass it through - response := &schemas.BifrostResponsesResponse{ - CreatedAt: state.CreatedAt, + // Flush any remaining pending tool results + flushPendingToolResults() + + // Flush any remaining pending tool calls (with tracking) + flushPendingToolCallsWithTracking() + + return anthropicMessages, systemContent +} + +// Helper function to convert Anthropic system content to Bifrost messages +func convertAnthropicSystemToBifrostMessages(systemContent *AnthropicContent) []schemas.ResponsesMessage { + var systemText string + if systemContent.ContentStr != nil { + systemText = *systemContent.ContentStr + } else if systemContent.ContentBlocks != nil { + // Combine text blocks from system content + var textParts []string + for _, block := range systemContent.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) } - if state.MessageID != nil { - response.ID = state.MessageID + } + systemText = strings.Join(textParts, "\n") + } + + if systemText != "" { + return []schemas.ResponsesMessage{{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &systemText, + }, + }} + } + return []schemas.ResponsesMessage{} +} + +// Helper function to convert a single Anthropic message to Bifrost messages +func convertSingleAnthropicMessageToBifrostMessages(msg *AnthropicMessage, isOutputMessage bool) []schemas.ResponsesMessage { + // Handle text content (simple case) + if msg.Content.ContentStr != nil { + roleVal := schemas.ResponsesMessageRoleType(msg.Role) + return []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: &roleVal, + Content: &schemas.ResponsesMessageContent{ + ContentStr: msg.Content.ContentStr, + }, + }, + } + } + + // Handle content blocks + if msg.Content.ContentBlocks != nil { + roleVal := schemas.ResponsesMessageRoleType(msg.Role) + return convertAnthropicContentBlocksToResponsesMessages(msg.Content.ContentBlocks, &roleVal, isOutputMessage) + } + + return []schemas.ResponsesMessage{} +} + +// Helper function to convert a single Anthropic message to Bifrost messages, grouping text and tool calls +// This keeps assistant messages with mixed text and tool_use blocks together +func convertSingleAnthropicMessageToBifrostMessagesGrouped(msg *AnthropicMessage, isOutputMessage bool) []schemas.ResponsesMessage { + // Handle text content (simple case) + if msg.Content.ContentStr != nil { + roleVal := schemas.ResponsesMessageRoleType(msg.Role) + return []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: &roleVal, + Content: &schemas.ResponsesMessageContent{ + ContentStr: msg.Content.ContentStr, + }, + }, + } + } + + // Handle content blocks with grouping for text and tool calls + if msg.Content.ContentBlocks != nil { + roleVal := schemas.ResponsesMessageRoleType(msg.Role) + return convertAnthropicContentBlocksToResponsesMessagesGrouped(msg.Content.ContentBlocks, &roleVal, isOutputMessage) + } + + return []schemas.ResponsesMessage{} +} + +// Helper function to convert Anthropic content blocks to Bifrost ResponsesMessages, grouping text and tool_use blocks +func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []AnthropicContentBlock, role *schemas.ResponsesMessageRoleType, isOutputMessage bool) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage + var reasoningContentBlocks []schemas.ResponsesMessageContentBlock + var accumulatedTextContent []schemas.ResponsesMessageContentBlock + var pendingToolUseBlocks []*AnthropicContentBlock // Accumulate tool_use blocks + + // Process content blocks + for _, block := range contentBlocks { + switch block.Type { + case AnthropicContentBlockTypeText: + if block.Text != nil { + if isOutputMessage { + // For output messages, accumulate text blocks (don't emit immediately) + accumulatedTextContent = append(accumulatedTextContent, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: block.Text, + }) + } else { + // For input messages, emit text immediately as separate message + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: role, + Content: &schemas.ResponsesMessageContent{ + ContentStr: block.Text, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } } - // Include usage information if present - if chunk.Usage != nil { - response.Usage = &schemas.ResponsesResponseUsage{ - InputTokens: chunk.Usage.InputTokens, - OutputTokens: chunk.Usage.OutputTokens, - TotalTokens: chunk.Usage.InputTokens + chunk.Usage.OutputTokens, + case AnthropicContentBlockTypeImage: + // Don't emit accumulated text or tool_use blocks for images + if block.Source != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{block.toBifrostResponsesImageBlock()}, + }, } - if chunk.Usage.CacheReadInputTokens > 0 { - if response.Usage.InputTokensDetails == nil { - response.Usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{} - } - response.Usage.InputTokensDetails.CachedTokens = chunk.Usage.CacheReadInputTokens + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + utils.GetRandomString(50)) } - if chunk.Usage.CacheCreationInputTokens > 0 { - if response.Usage.OutputTokensDetails == nil { - response.Usage.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{} + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + + case AnthropicContentBlockTypeThinking: + if block.Thinking != nil { + // Collect reasoning blocks without flushing accumulated text/tool blocks + reasoningContentBlocks = append(reasoningContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: block.Thinking, + Signature: block.Signature, + }) + } + + case AnthropicContentBlockTypeRedactedThinking: + // Handle redacted thinking (encrypted content) + if block.Data != nil { + bifrostMsg := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + utils.GetRandomString(50)), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + EncryptedContent: block.Data, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + + case AnthropicContentBlockTypeToolUse: + // Accumulate tool_use blocks to group them together + if block.ID != nil && block.Name != nil { + blockCopy := block + pendingToolUseBlocks = append(pendingToolUseBlocks, &blockCopy) + } + + case AnthropicContentBlockTypeToolResult: + // Convert tool result to function call output message + if block.ToolUseID != nil { + if block.Content != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ToolUseID, + }, + } + // Initialize the nested struct before any writes + bifrostMsg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} + + if block.Content.ContentStr != nil { + bifrostMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = block.Content.ContentStr + } else if block.Content.ContentBlocks != nil { + var toolMsgContentBlocks []schemas.ResponsesMessageContentBlock + for _, contentBlock := range block.Content.ContentBlocks { + switch contentBlock.Type { + case AnthropicContentBlockTypeText: + if contentBlock.Text != nil { + var blockType schemas.ResponsesMessageContentBlockType + if isOutputMessage { + blockType = schemas.ResponsesOutputMessageContentTypeText + } else { + blockType = schemas.ResponsesInputMessageContentBlockTypeText + } + toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: blockType, + Text: contentBlock.Text, + }) + } + case AnthropicContentBlockTypeImage: + if contentBlock.Source != nil { + toolMsgContentBlocks = append(toolMsgContentBlocks, contentBlock.toBifrostResponsesImageBlock()) + } + } + } + bifrostMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = toolMsgContentBlocks } - response.Usage.OutputTokensDetails.CachedTokens = chunk.Usage.CacheCreationInputTokens + bifrostMessages = append(bifrostMessages, bifrostMsg) } } - // Use a special response type that indicates this is a message delta - // The integration layer can detect this and pass through the raw event - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseType("anthropic.passthrough.message_delta"), // Custom type for message deltas - SequenceNumber: sequenceNumber, - Response: response, - }}, nil, false - } else { - // Message-level updates (like stop reason, usage, etc.) - // Note: We don't emit output_item.done here because items are already closed - // by content_block_stop. This event is informational only. - return nil, nil, false - } + case AnthropicContentBlockTypeMCPToolUse: + // Accumulate MCP tool use blocks + if block.ID != nil && block.Name != nil { + blockCopy := block + pendingToolUseBlocks = append(pendingToolUseBlocks, &blockCopy) + } - case AnthropicStreamEventTypeMessageStop: - // Message stop - emit response.completed (OpenAI-style) - response := &schemas.BifrostResponsesResponse{ - CreatedAt: state.CreatedAt, - } - if state.MessageID != nil { - response.ID = state.MessageID - } - if state.Model != nil { - response.Model = *state.Model + case AnthropicContentBlockTypeMCPToolResult: + // Handle MCP tool results directly without flushing other blocks + // MCP results will be emitted as separate messages } + } - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeCompleted, - SequenceNumber: sequenceNumber, - Response: response, - }}, nil, true // Indicate stream is complete + // For Bedrock compatibility: reasoning blocks must come before text/tool blocks + // If we have reasoning + text/tools, emit reasoning first, then text/tools + // Otherwise emit them separately as before + if len(reasoningContentBlocks) > 0 { + bifrostMsg := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + utils.GetRandomString(50)), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: reasoningContentBlocks, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } - case AnthropicStreamEventTypePing: - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypePing, - SequenceNumber: sequenceNumber, - }}, nil, false + // Flush any remaining pending blocks + if len(accumulatedTextContent) > 0 { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: role, + } + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + utils.GetRandomString(50)) + bifrostMsg.Content = &schemas.ResponsesMessageContent{ + ContentBlocks: accumulatedTextContent, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + } - case AnthropicStreamEventTypeError: - if chunk.Error != nil { - // Send error event - bifrostErr := &schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Type: &chunk.Error.Type, - Message: chunk.Error.Message, + // Emit any accumulated tool_use blocks as function_calls + if len(pendingToolUseBlocks) > 0 { + for _, toolBlock := range pendingToolUseBlocks { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: toolBlock.ID, + Name: toolBlock.Name, }, } + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + utils.GetRandomString(50)) + } - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeError, - SequenceNumber: sequenceNumber, - Message: &chunk.Error.Message, - }}, bifrostErr, false + // Check for computer tool use + if toolBlock.Name != nil && *toolBlock.Name == string(AnthropicToolNameComputer) { + bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeComputerCall) + bifrostMsg.ResponsesToolMessage.Name = nil + if inputMap, ok := toolBlock.Input.(map[string]interface{}); ok { + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), + } + } + } else { + bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(toolBlock.Input)) + } + + bifrostMessages = append(bifrostMessages, bifrostMsg) } } - return nil, nil, false + return bifrostMessages } -// ToAnthropicResponsesStreamResponse converts a Bifrost Responses stream response to Anthropic SSE string format -func ToAnthropicResponsesStreamResponse(bifrostResp *schemas.BifrostResponsesStreamResponse) *AnthropicStreamEvent { - if bifrostResp == nil { - return nil - } - - streamResp := &AnthropicStreamEvent{} +// Helper function to convert Anthropic content blocks to Bifrost ResponsesMessages +func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicContentBlock, role *schemas.ResponsesMessageRoleType, isOutputMessage bool) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage + var reasoningContentBlocks []schemas.ResponsesMessageContentBlock - // Map ResponsesStreamResponse types to Anthropic stream events - switch bifrostResp.Type { - case schemas.ResponsesStreamResponseTypeCreated, schemas.ResponsesStreamResponseTypeInProgress: - // These are emitted from message_start - convert back to message_start - streamResp.Type = AnthropicStreamEventTypeMessageStart - if bifrostResp.Response != nil { - streamMessage := &AnthropicMessageResponse{ - Type: "message", - Role: "assistant", + // Process content blocks + for _, block := range contentBlocks { + switch block.Type { + case AnthropicContentBlockTypeText: + if block.Text != nil { + var bifrostMsg schemas.ResponsesMessage + if isOutputMessage { + // For output messages, use ContentBlocks with ResponsesOutputMessageContentTypeText + bifrostMsg = schemas.ResponsesMessage{ + ID: schemas.Ptr("msg_" + utils.GetRandomString(50)), + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: block.Text, + }, + }, + }, + } + } else { + // For input messages, use ContentStr + bifrostMsg = schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: role, + Content: &schemas.ResponsesMessageContent{ + ContentStr: block.Text, + }, + } + } + bifrostMessages = append(bifrostMessages, bifrostMsg) } - if bifrostResp.Response.ID != nil { - streamMessage.ID = *bifrostResp.Response.ID + case AnthropicContentBlockTypeImage: + if block.Source != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{block.toBifrostResponsesImageBlock()}, + }, + } + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + utils.GetRandomString(50)) + } + bifrostMessages = append(bifrostMessages, bifrostMsg) } - // Preserve model from Response if available, otherwise use ExtraFields - if bifrostResp.ExtraFields.ModelRequested != "" { - if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { - streamMessage.Model = bifrostResp.Response.Model + case AnthropicContentBlockTypeThinking: + if block.Thinking != nil { + // Collect reasoning blocks to create a single reasoning message + reasoningContentBlocks = append(reasoningContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: block.Thinking, + Signature: block.Signature, + }) + } + case AnthropicContentBlockTypeRedactedThinking: + if block.Data != nil { + bifrostMsg := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + utils.GetRandomString(50)), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + EncryptedContent: block.Data, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + case AnthropicContentBlockTypeToolUse: + // Convert tool use to function call message + if block.ID != nil && block.Name != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ID, + Name: block.Name, + }, + } + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + utils.GetRandomString(50)) + } + + // here need to check for computer tool use + if block.Name != nil && *block.Name == string(AnthropicToolNameComputer) { + bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeComputerCall) + bifrostMsg.ResponsesToolMessage.Name = nil + if inputMap, ok := block.Input.(map[string]interface{}); ok { + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), + } + } } else { - streamMessage.Model = bifrostResp.ExtraFields.ModelRequested + bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(block.Input)) } + bifrostMessages = append(bifrostMessages, bifrostMsg) } - // Preserve usage if available - if bifrostResp.Response.Usage != nil { - streamMessage.Usage = &AnthropicUsage{ - InputTokens: bifrostResp.Response.Usage.InputTokens, - OutputTokens: bifrostResp.Response.Usage.OutputTokens, + case AnthropicContentBlockTypeToolResult: + // Convert tool result to function call output message + if block.ToolUseID != nil { + if block.Content != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ToolUseID, + }, + } + // Initialize the nested struct before any writes + bifrostMsg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} + + if block.Content.ContentStr != nil { + bifrostMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = block.Content.ContentStr + } else if block.Content.ContentBlocks != nil { + var toolMsgContentBlocks []schemas.ResponsesMessageContentBlock + for _, contentBlock := range block.Content.ContentBlocks { + switch contentBlock.Type { + case AnthropicContentBlockTypeText: + if contentBlock.Text != nil { + var blockType schemas.ResponsesMessageContentBlockType + if isOutputMessage { + blockType = schemas.ResponsesOutputMessageContentTypeText + } else { + blockType = schemas.ResponsesInputMessageContentBlockTypeText + } + toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: blockType, + Text: contentBlock.Text, + }) + } + case AnthropicContentBlockTypeImage: + if contentBlock.Source != nil { + toolMsgContentBlocks = append(toolMsgContentBlocks, contentBlock.toBifrostResponsesImageBlock()) + } + } + } + bifrostMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = toolMsgContentBlocks + } + bifrostMessages = append(bifrostMessages, bifrostMsg) } - if bifrostResp.Response.Usage.InputTokensDetails != nil && bifrostResp.Response.Usage.InputTokensDetails.CachedTokens > 0 { - streamMessage.Usage.CacheReadInputTokens = bifrostResp.Response.Usage.InputTokensDetails.CachedTokens + } + case AnthropicContentBlockTypeMCPToolUse: + // Convert MCP tool use to MCP call (assistant's tool call) + if block.ID != nil && block.Name != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), + ID: block.ID, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Name: block.Name, + Arguments: schemas.Ptr(schemas.JsonifyInput(block.Input)), + }, } - if bifrostResp.Response.Usage.OutputTokensDetails != nil && bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens > 0 { - streamMessage.Usage.CacheCreationInputTokens = bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens + if block.ServerName != nil { + bifrostMsg.ResponsesToolMessage.ResponsesMCPToolCall = &schemas.ResponsesMCPToolCall{ + ServerLabel: *block.ServerName, + } } + bifrostMessages = append(bifrostMessages, bifrostMsg) } - streamResp.Message = streamMessage + case AnthropicContentBlockTypeMCPToolResult: + // Convert MCP tool result to MCP call (user's tool result) + if block.ToolUseID != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ToolUseID, + }, + } + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + utils.GetRandomString(50)) + } + // Initialize the nested struct before any writes + bifrostMsg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} + + if block.Content != nil { + if block.Content.ContentStr != nil { + bifrostMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = block.Content.ContentStr + } else if block.Content.ContentBlocks != nil { + var toolMsgContentBlocks []schemas.ResponsesMessageContentBlock + for _, contentBlock := range block.Content.ContentBlocks { + if contentBlock.Type == AnthropicContentBlockTypeText { + if contentBlock.Text != nil { + var blockType schemas.ResponsesMessageContentBlockType + if isOutputMessage { + blockType = schemas.ResponsesOutputMessageContentTypeText + } else { + blockType = schemas.ResponsesInputMessageContentBlockTypeText + } + toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: blockType, + Text: contentBlock.Text, + }) + } + } + } + bifrostMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = toolMsgContentBlocks + } + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + default: + // Handle other block types if needed } + } - case schemas.ResponsesStreamResponseTypeOutputItemAdded: - // Check if this is a computer tool call - if bifrostResp.Item != nil && - bifrostResp.Item.Type != nil && - *bifrostResp.Item.Type == schemas.ResponsesMessageTypeComputerCall { + // Handle reasoning blocks - prepend reasoning message if we collected any + // This ensures reasoning comes before any text/tool blocks (Bedrock compatibility) + if len(reasoningContentBlocks) > 0 { + reasoningMessage := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + utils.GetRandomString(50)), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: reasoningContentBlocks, + }, + } + // Prepend the reasoning message to the start of the messages list + // This ensures reasoning comes before text/tool responses + bifrostMessages = append([]schemas.ResponsesMessage{reasoningMessage}, bifrostMessages...) + } - // Computer tool - emit content_block_start - streamResp.Type = AnthropicStreamEventTypeContentBlockStart + return bifrostMessages +} - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } else if bifrostResp.OutputIndex != nil { - streamResp.Index = bifrostResp.OutputIndex +// Helper functions for converting individual Bifrost message types to Anthropic messages +// convertBifrostMessageToAnthropicSystemContent converts a Bifrost system message to Anthropic system content +func convertBifrostMessageToAnthropicSystemContent(msg *schemas.ResponsesMessage) *AnthropicContent { + if msg.Content != nil { + if msg.Content.ContentStr != nil { + return &AnthropicContent{ + ContentStr: msg.Content.ContentStr, } - - // Build the content_block - contentBlock := &AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolUse, - ID: bifrostResp.Item.ID, // The tool use ID - Name: schemas.Ptr(string(AnthropicToolNameComputer)), // "computer" + } else if msg.Content.ContentBlocks != nil { + contentBlocks := convertBifrostContentBlocksToAnthropic(msg.Content.ContentBlocks) + if len(contentBlocks) > 0 { + return &AnthropicContent{ + ContentBlocks: contentBlocks, + } } + } + } + return nil +} - streamResp.ContentBlock = contentBlock - - } else { - // Text or other content blocks - emit content_block_start - streamResp.Type = AnthropicStreamEventTypeContentBlockStart - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } else if bifrostResp.OutputIndex != nil { - streamResp.Index = bifrostResp.OutputIndex +// convertBifrostMessageToAnthropicMessage converts a regular Bifrost message to Anthropic message +func convertBifrostMessageToAnthropicMessage(msg *schemas.ResponsesMessage, pendingReasoningContentBlocks *[]AnthropicContentBlock) *AnthropicMessage { + anthropicMsg := AnthropicMessage{} + + // Set role + if msg.Role != nil { + switch *msg.Role { + case schemas.ResponsesInputMessageRoleUser: + anthropicMsg.Role = AnthropicMessageRoleUser + case schemas.ResponsesInputMessageRoleAssistant: + anthropicMsg.Role = AnthropicMessageRoleAssistant + default: + anthropicMsg.Role = AnthropicMessageRoleUser // Default fallback + } + } else { + anthropicMsg.Role = AnthropicMessageRoleUser // Default fallback + } + + // Add any pending reasoning content blocks to the message + // Only add reasoning blocks to assistant messages (thinking blocks can only appear in assistant messages in Anthropic) + if len(*pendingReasoningContentBlocks) > 0 && anthropicMsg.Role == AnthropicMessageRoleAssistant { + // copy the pending reasoning content blocks + copied := make([]AnthropicContentBlock, len(*pendingReasoningContentBlocks)) + copy(copied, *pendingReasoningContentBlocks) + contentBlocks := copied + *pendingReasoningContentBlocks = nil + // Add content blocks after pending reasoning content blocks are added + if msg.Content != nil { + if msg.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, AnthropicContentBlock{ + Type: AnthropicContentBlockTypeText, + Text: msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + contentBlocks = append(contentBlocks, convertBifrostContentBlocksToAnthropic(msg.Content.ContentBlocks)...) } - - // Build content_block based on item type - if bifrostResp.Item != nil { - contentBlock := &AnthropicContentBlock{} - if bifrostResp.Item.Type != nil { - switch *bifrostResp.Item.Type { - case schemas.ResponsesMessageTypeMessage: - contentBlock.Type = AnthropicContentBlockTypeText - contentBlock.Text = schemas.Ptr("") - case schemas.ResponsesMessageTypeReasoning: - contentBlock.Type = AnthropicContentBlockTypeThinking - contentBlock.Thinking = schemas.Ptr("") - // Preserve signature if present - if bifrostResp.Item.ResponsesReasoning != nil && bifrostResp.Item.ResponsesReasoning.EncryptedContent != nil { - contentBlock.Signature = bifrostResp.Item.ResponsesReasoning.EncryptedContent - } - case schemas.ResponsesMessageTypeFunctionCall: - contentBlock.Type = AnthropicContentBlockTypeToolUse - if bifrostResp.Item.ResponsesToolMessage != nil { - contentBlock.ID = bifrostResp.Item.ResponsesToolMessage.CallID - contentBlock.Name = bifrostResp.Item.ResponsesToolMessage.Name - } - case schemas.ResponsesMessageTypeMCPCall: - contentBlock.Type = AnthropicContentBlockTypeMCPToolUse - if bifrostResp.Item.ResponsesToolMessage != nil { - contentBlock.ID = bifrostResp.Item.ID - contentBlock.Name = bifrostResp.Item.ResponsesToolMessage.Name - if bifrostResp.Item.ResponsesToolMessage.ResponsesMCPToolCall != nil { - contentBlock.ServerName = &bifrostResp.Item.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel - } - } + } + anthropicMsg.Content = AnthropicContent{ + ContentBlocks: contentBlocks, + } + } else { + // Convert content + if msg.Content != nil { + if msg.Content.ContentStr != nil { + anthropicMsg.Content = AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + contentBlocks := convertBifrostContentBlocksToAnthropic(msg.Content.ContentBlocks) + if len(contentBlocks) > 0 { + anthropicMsg.Content = AnthropicContent{ + ContentBlocks: contentBlocks, } } - if contentBlock.Type != "" { - streamResp.ContentBlock = contentBlock + } + } + } + + return &anthropicMsg +} + +// convertBifrostReasoningToAnthropicThinking converts a Bifrost reasoning message to Anthropic thinking blocks +func convertBifrostReasoningToAnthropicThinking(msg *schemas.ResponsesMessage) []AnthropicContentBlock { + var thinkingBlocks []AnthropicContentBlock + + if msg.Content != nil && msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Type == schemas.ResponsesOutputMessageContentTypeReasoning && block.Text != nil { + thinkingBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeThinking, + Thinking: block.Text, + Signature: block.Signature, } + thinkingBlocks = append(thinkingBlocks, thinkingBlock) } } - case schemas.ResponsesStreamResponseTypeContentPartAdded: - streamResp.Type = AnthropicStreamEventTypeContentBlockStart - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } - if bifrostResp.Part != nil { - contentBlock := &AnthropicContentBlock{} - switch bifrostResp.Part.Type { - case schemas.ResponsesOutputMessageContentTypeText: - contentBlock.Type = AnthropicContentBlockTypeText - if bifrostResp.Part.Text != nil { - contentBlock.Text = bifrostResp.Part.Text + } else if msg.ResponsesReasoning != nil { + if msg.ResponsesReasoning.Summary != nil { + for _, reasoningContent := range msg.ResponsesReasoning.Summary { + thinkingBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeThinking, + Thinking: &reasoningContent.Text, } + thinkingBlocks = append(thinkingBlocks, thinkingBlock) } - streamResp.ContentBlock = contentBlock + } else if msg.ResponsesReasoning.EncryptedContent != nil { + thinkingBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeRedactedThinking, + Data: msg.ResponsesReasoning.EncryptedContent, + } + thinkingBlocks = append(thinkingBlocks, thinkingBlock) } + } - case schemas.ResponsesStreamResponseTypeOutputTextDelta: - streamResp.Type = AnthropicStreamEventTypeContentBlockDelta - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex + return thinkingBlocks +} + +// convertBifrostFunctionCallToAnthropicToolUse converts a Bifrost function call to Anthropic tool use +func convertBifrostFunctionCallToAnthropicToolUse(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil { + toolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolUse, } - if bifrostResp.Delta != nil { - streamResp.Delta = &AnthropicStreamDelta{ - Type: AnthropicStreamDeltaTypeText, - Text: bifrostResp.Delta, - } + + if msg.ResponsesToolMessage.CallID != nil { + toolUseBlock.ID = msg.ResponsesToolMessage.CallID + } + if msg.ResponsesToolMessage.Name != nil { + toolUseBlock.Name = msg.ResponsesToolMessage.Name } - case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: - streamResp.Type = AnthropicStreamEventTypeContentBlockDelta - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex + // Parse arguments as JSON input + if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { + toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) } - if bifrostResp.Arguments != nil { - streamResp.Delta = &AnthropicStreamDelta{ - Type: AnthropicStreamDeltaTypeInputJSON, - PartialJSON: bifrostResp.Arguments, - } + + return &toolUseBlock + } + return nil +} + +// convertBifrostFunctionCallOutputToAnthropicMessage converts a Bifrost function call output to Anthropic message +func convertBifrostFunctionCallOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { + if msg.ResponsesToolMessage != nil { + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolResult, + ToolUseID: msg.ResponsesToolMessage.CallID, } - case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta: - streamResp.Type = AnthropicStreamEventTypeContentBlockDelta - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex + if msg.ResponsesToolMessage.Output != nil { + toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) } - if bifrostResp.Delta != nil { - // Check if this looks like a signature (long base64 string, typically >200 chars) - // Signatures are base64 encoded and much longer than typical thinking text - deltaStr := *bifrostResp.Delta - if len(deltaStr) > 200 && isBase64Like(deltaStr) { - // This is likely a signature_delta - streamResp.Delta = &AnthropicStreamDelta{ - Type: AnthropicStreamDeltaTypeSignature, - Signature: bifrostResp.Delta, - } - } else { - // This is a thinking_delta - streamResp.Delta = &AnthropicStreamDelta{ - Type: AnthropicStreamDeltaTypeThinking, - Thinking: bifrostResp.Delta, - } - } + + return &AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentBlocks: []AnthropicContentBlock{toolResultBlock}, + }, } + } + return nil +} - case schemas.ResponsesStreamResponseTypeContentPartDone: - streamResp.Type = AnthropicStreamEventTypeContentBlockStop - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex +// convertBifrostFunctionCallOutputToAnthropicToolResultBlock converts a Bifrost function call output to a single tool result block +// This is used to accumulate multiple tool results into a single user message +func convertBifrostFunctionCallOutputToAnthropicToolResultBlock(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil { + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolResult, + ToolUseID: msg.ResponsesToolMessage.CallID, } - case schemas.ResponsesStreamResponseTypeOutputItemDone: - if bifrostResp.Item != nil && - bifrostResp.Item.Type != nil && - *bifrostResp.Item.Type == schemas.ResponsesMessageTypeComputerCall { + if msg.ResponsesToolMessage.Output != nil { + toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) + } - // Computer tool complete - emit content_block_delta with the action, then stop - // Note: We're sending the complete action JSON in one delta - streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + return &toolResultBlock + } + return nil +} - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } else if bifrostResp.OutputIndex != nil { - streamResp.Index = bifrostResp.OutputIndex - } +// convertBifrostComputerCallOutputToAnthropicToolResultBlock converts a Bifrost computer call output to a single tool result block +// This is used to accumulate multiple tool results into a single user message +func convertBifrostComputerCallOutputToAnthropicToolResultBlock(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolResult, + ToolUseID: msg.ResponsesToolMessage.CallID, + } - // Convert the action to Anthropic format and marshal to JSON - if bifrostResp.Item.ResponsesToolMessage != nil && - bifrostResp.Item.ResponsesToolMessage.Action != nil && - bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { + // Handle output + if msg.ResponsesToolMessage.Output != nil { + toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) + } - actionInput := convertResponsesToAnthropicComputerAction( - bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction, - ) + return &toolResultBlock + } + return nil +} - // Marshal the action to JSON string - if jsonBytes, err := json.Marshal(actionInput); err == nil { - jsonStr := string(jsonBytes) - streamResp.Delta = &AnthropicStreamDelta{ - Type: AnthropicStreamDeltaTypeInputJSON, - PartialJSON: &jsonStr, - } - } - } - } else { - // For text blocks and other content blocks, emit content_block_stop - streamResp.Type = AnthropicStreamEventTypeContentBlockStop - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } else if bifrostResp.OutputIndex != nil { - streamResp.Index = bifrostResp.OutputIndex - } +// convertBifrostMCPCallOutputToAnthropicToolResultBlock converts a Bifrost MCP call output to a single tool result block +// This is used to accumulate multiple tool results into a single user message +func convertBifrostMCPCallOutputToAnthropicToolResultBlock(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeMCPToolResult, + ToolUseID: msg.ResponsesToolMessage.CallID, } - case schemas.ResponsesStreamResponseTypePing: - streamResp.Type = AnthropicStreamEventTypePing - case schemas.ResponsesStreamResponseTypeCompleted: - streamResp.Type = AnthropicStreamEventTypeMessageStop + // Handle output + if msg.ResponsesToolMessage.Output != nil { + toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) + } - case schemas.ResponsesStreamResponseTypeMCPCallArgumentsDelta: - // MCP call arguments delta - convert to content_block_delta with input_json - streamResp.Type = AnthropicStreamEventTypeContentBlockDelta - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } else if bifrostResp.OutputIndex != nil { - streamResp.Index = bifrostResp.OutputIndex + return &toolResultBlock + } + return nil +} + +// convertBifrostItemReferenceToAnthropicMessage converts a Bifrost item reference to Anthropic message +func convertBifrostItemReferenceToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { + if msg.Content != nil && msg.Content.ContentStr != nil { + referenceMsg := AnthropicMessage{ + Role: AnthropicMessageRoleUser, // Default to user for references } - if bifrostResp.Delta != nil { - streamResp.Delta = &AnthropicStreamDelta{ - Type: AnthropicStreamDeltaTypeInputJSON, - PartialJSON: bifrostResp.Delta, - } + if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleAssistant { + referenceMsg.Role = AnthropicMessageRoleAssistant } - case schemas.ResponsesStreamResponseTypeMCPCallCompleted: - // MCP call completed - emit content_block_stop - streamResp.Type = AnthropicStreamEventTypeContentBlockStop - if bifrostResp.ContentIndex != nil { - streamResp.Index = bifrostResp.ContentIndex - } else if bifrostResp.OutputIndex != nil { - streamResp.Index = bifrostResp.OutputIndex + referenceMsg.Content = AnthropicContent{ + ContentStr: msg.Content.ContentStr, } - case schemas.ResponsesStreamResponseTypeMCPCallFailed: - // MCP call failed - emit error event - streamResp.Type = AnthropicStreamEventTypeError - errorMsg := "MCP call failed" - if bifrostResp.Message != nil { - errorMsg = *bifrostResp.Message + return &referenceMsg + } + return nil +} + +// convertBifrostComputerCallToAnthropicToolUse converts a Bifrost computer call to Anthropic tool use +func convertBifrostComputerCallToAnthropicToolUse(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil { + toolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolUse, + Name: schemas.Ptr(string(AnthropicToolNameComputer)), } - streamResp.Error = &AnthropicStreamError{ - Type: "error", - Message: errorMsg, + if msg.ResponsesToolMessage.CallID != nil { + toolUseBlock.ID = msg.ResponsesToolMessage.CallID + } + if msg.ResponsesToolMessage.Name != nil { + toolUseBlock.Name = msg.ResponsesToolMessage.Name } - case schemas.ResponsesStreamResponseTypeError: - streamResp.Type = AnthropicStreamEventTypeError - if bifrostResp.Message != nil { - streamResp.Error = &AnthropicStreamError{ - Type: "error", - Message: *bifrostResp.Message, - } + if msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { + toolUseBlock.Input = convertResponsesToAnthropicComputerAction(msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) } - case schemas.ResponsesStreamResponseType("anthropic.passthrough.message_delta"): - // Handle message_delta events - convert back to Anthropic format - streamResp.Type = AnthropicStreamEventTypeMessageDelta - streamResp.Delta = &AnthropicStreamDelta{} + return &toolUseBlock + } + return nil +} - // Include usage information if present - if bifrostResp.Response != nil && bifrostResp.Response.Usage != nil { - streamResp.Usage = &AnthropicUsage{ - InputTokens: bifrostResp.Response.Usage.InputTokens, - OutputTokens: bifrostResp.Response.Usage.OutputTokens, - } - if bifrostResp.Response.Usage.InputTokensDetails != nil && bifrostResp.Response.Usage.InputTokensDetails.CachedTokens > 0 { - streamResp.Usage.CacheReadInputTokens = bifrostResp.Response.Usage.InputTokensDetails.CachedTokens - } - if bifrostResp.Response.Usage.OutputTokensDetails != nil && bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens > 0 { - streamResp.Usage.CacheCreationInputTokens = bifrostResp.Response.Usage.OutputTokensDetails.CachedTokens - } +// convertBifrostMCPCallToAnthropicToolUse converts a Bifrost MCP call to Anthropic tool use +func convertBifrostMCPCallToAnthropicToolUse(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { + toolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeMCPToolUse, } - default: - // Unknown event type, return empty - return nil - } + if msg.ID != nil { + toolUseBlock.ID = msg.ID + } + toolUseBlock.Name = msg.ResponsesToolMessage.Name - return streamResp -} + // Set server name if present + if msg.ResponsesToolMessage.ResponsesMCPToolCall != nil && msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel != "" { + toolUseBlock.ServerName = &msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel + } -// ToAnthropicResponsesStreamError converts a BifrostError to Anthropic responses streaming error in SSE format -func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string { - if bifrostErr == nil { - return "" - } + // Parse arguments as JSON input + if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { + toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) + } - // Safely extract message from nested error - message := "" - if bifrostErr.Error != nil { - message = bifrostErr.Error.Message + return &toolUseBlock } + return nil +} - streamResp := &AnthropicStreamEvent{ - Type: AnthropicStreamEventTypeError, - Error: &AnthropicStreamError{ - Type: "error", - Message: message, - }, +// convertBifrostMCPCallOutputToAnthropicMessage converts a Bifrost MCP call output to Anthropic message +func convertBifrostMCPCallOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeMCPToolResult, + ID: msg.ResponsesToolMessage.CallID, } - // Marshal to JSON - jsonData, err := json.Marshal(streamResp) - if err != nil { - return "" + if msg.ResponsesToolMessage.Output != nil { + toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) } - // Format as Anthropic SSE error event - return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) + return &AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentBlocks: []AnthropicContentBlock{toolResultBlock}, + }, + } } -// convertAnthropicMessageToBifrostResponsesMessages converts AnthropicMessage to ChatMessage format -func convertAnthropicMessageToBifrostResponsesMessages(msg *AnthropicMessage) []schemas.ResponsesMessage { - var bifrostMessages []schemas.ResponsesMessage +// convertBifrostMCPApprovalToAnthropicToolUse converts a Bifrost MCP approval request to Anthropic tool use +func convertBifrostMCPApprovalToAnthropicToolUse(msg *schemas.ResponsesMessage) *AnthropicContentBlock { + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { + toolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeMCPToolUse, + } - // Handle text content - if msg.Content.ContentStr != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesMessageRoleType(msg.Role)), - Content: &schemas.ResponsesMessageContent{ - ContentStr: msg.Content.ContentStr, - }, + if msg.ID != nil { + toolUseBlock.ID = msg.ID } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } else if msg.Content.ContentBlocks != nil { - // Handle content blocks - for _, block := range msg.Content.ContentBlocks { - switch block.Type { - case AnthropicContentBlockTypeText: - if block.Text != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesMessageRoleType(msg.Role)), - Content: &schemas.ResponsesMessageContent{ - ContentStr: block.Text, - }, - } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } - case AnthropicContentBlockTypeImage: - if block.Source != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesMessageRoleType(msg.Role)), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{block.toBifrostResponsesImageBlock()}, - }, - } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } - case AnthropicContentBlockTypeToolUse: - // Convert tool use to function call message - if block.ID != nil && block.Name != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: block.ID, - Name: block.Name, - }, - } + toolUseBlock.Name = msg.ResponsesToolMessage.Name - // here need to check for computer tool use - if block.Name != nil && *block.Name == string(AnthropicToolNameComputer) { - bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeComputerCall) - bifrostMsg.ResponsesToolMessage.Name = nil - if inputMap, ok := block.Input.(map[string]interface{}); ok { - bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), - } - } - } else { - bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(block.Input)) - } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } - case AnthropicContentBlockTypeToolResult: - // Convert tool result to function call output message - if block.ToolUseID != nil { - if block.Content != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: block.ToolUseID, - }, - } - // Initialize the nested struct before any writes - bifrostMsg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} - - if block.Content.ContentStr != nil { - bifrostMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = block.Content.ContentStr - } else if block.Content.ContentBlocks != nil { - var toolMsgContentBlocks []schemas.ResponsesMessageContentBlock - for _, contentBlock := range block.Content.ContentBlocks { - switch contentBlock.Type { - case AnthropicContentBlockTypeText: - if contentBlock.Text != nil { - toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: contentBlock.Text, - }) - } - case AnthropicContentBlockTypeImage: - if contentBlock.Source != nil { - toolMsgContentBlocks = append(toolMsgContentBlocks, contentBlock.toBifrostResponsesImageBlock()) - } - } - } - bifrostMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = toolMsgContentBlocks - } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } - } - case AnthropicContentBlockTypeMCPToolUse: - // Convert MCP tool use to MCP call (assistant's tool call) - if block.ID != nil && block.Name != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), - ID: block.ID, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - Name: block.Name, - Arguments: schemas.Ptr(schemas.JsonifyInput(block.Input)), - }, - } - if block.ServerName != nil { - bifrostMsg.ResponsesToolMessage.ResponsesMCPToolCall = &schemas.ResponsesMCPToolCall{ - ServerLabel: *block.ServerName, - } - } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } - case AnthropicContentBlockTypeMCPToolResult: - // Convert MCP tool result to MCP call (user's tool result) - if block.ToolUseID != nil { - bifrostMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: block.ToolUseID, - }, - } - // Initialize the nested struct before any writes - bifrostMsg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} + // Set server name if present + if msg.ResponsesToolMessage.ResponsesMCPToolCall != nil && msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel != "" { + toolUseBlock.ServerName = &msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel + } - if block.Content != nil { - if block.Content.ContentStr != nil { - bifrostMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = block.Content.ContentStr - } else if block.Content.ContentBlocks != nil { - var toolMsgContentBlocks []schemas.ResponsesMessageContentBlock - for _, contentBlock := range block.Content.ContentBlocks { - if contentBlock.Type == AnthropicContentBlockTypeText { - if contentBlock.Text != nil { - toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: contentBlock.Text, - }) - } - } - } - bifrostMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = toolMsgContentBlocks - } - } - bifrostMessages = append(bifrostMessages, bifrostMsg) - } + // Parse arguments as JSON input + if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { + toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) + } + + return &toolUseBlock + } + return nil +} +// convertBifrostUnsupportedToolCallToAnthropicMessage converts unsupported tool calls to text messages +func convertBifrostUnsupportedToolCallToAnthropicMessage(msg *schemas.ResponsesMessage, msgType schemas.ResponsesMessageType) *AnthropicMessage { + if msg.ResponsesToolMessage != nil { + var description string + if msg.ResponsesToolMessage.Name != nil { + description = fmt.Sprintf("Tool call: %s", *msg.ResponsesToolMessage.Name) + if msg.ResponsesToolMessage.Arguments != nil { + description += fmt.Sprintf(" with arguments: %s", *msg.ResponsesToolMessage.Arguments) } + } else { + description = fmt.Sprintf("Tool call of type: %s", msgType) + } + + return &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + Content: AnthropicContent{ + ContentStr: &description, + }, } } + return nil +} - return bifrostMessages +// convertBifrostComputerCallOutputToAnthropicMessage converts a Bifrost computer call output to Anthropic message +func convertBifrostComputerCallOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { + if msg.ResponsesToolMessage != nil { + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolResult, + ToolUseID: msg.ResponsesToolMessage.CallID, + } + + if msg.ResponsesToolMessage.Output != nil { + toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) + } + + return &AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentBlocks: []AnthropicContentBlock{toolResultBlock}, + }, + } + } + return nil +} + +// convertBifrostToolOutputToAnthropicMessage converts tool outputs to user messages +func convertBifrostToolOutputToAnthropicMessage(msg *schemas.ResponsesMessage) *AnthropicMessage { + if msg.ResponsesToolMessage != nil { + var outputText string + // Try to extract output text based on tool type + if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { + outputText = *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + } + + if outputText != "" { + return &AnthropicMessage{ + Role: AnthropicMessageRoleUser, + Content: AnthropicContent{ + ContentStr: &outputText, + }, + } + } + } + return nil } // convertAnthropicToolToBifrost converts AnthropicTool to schemas.Tool @@ -1733,451 +3119,66 @@ func convertAnthropicToolChoiceToBifrost(toolChoice *AnthropicToolChoice) *schem return bifrostToolChoice } -// flushPendingToolCalls is a helper that flushes accumulated tool calls into an assistant message -func flushPendingToolCalls( - pendingToolCalls []AnthropicContentBlock, +// flushPendingContentBlocks is a helper that flushes accumulated content blocks into an assistant message +func flushPendingContentBlocks( + pendingContentBlocks []AnthropicContentBlock, currentAssistantMessage *AnthropicMessage, anthropicMessages []AnthropicMessage, ) ([]AnthropicContentBlock, *AnthropicMessage, []AnthropicMessage) { - if len(pendingToolCalls) > 0 && currentAssistantMessage != nil { + if len(pendingContentBlocks) > 0 && currentAssistantMessage != nil { // Copy the slice to avoid aliasing issues - copied := make([]AnthropicContentBlock, len(pendingToolCalls)) - copy(copied, pendingToolCalls) + copied := make([]AnthropicContentBlock, len(pendingContentBlocks)) + copy(copied, pendingContentBlocks) currentAssistantMessage.Content = AnthropicContent{ ContentBlocks: copied, } anthropicMessages = append(anthropicMessages, *currentAssistantMessage) // Return nil values to indicate flushed state - return nil, nil, anthropicMessages - } - // Return unchanged values if no flush was needed - return pendingToolCalls, currentAssistantMessage, anthropicMessages -} - -// convertToolOutputToAnthropicContent converts tool output to Anthropic content format -func convertToolOutputToAnthropicContent(output *schemas.ResponsesToolMessageOutputStruct) *AnthropicContent { - if output == nil { - return nil - } - - if output.ResponsesToolCallOutputStr != nil { - return &AnthropicContent{ - ContentStr: output.ResponsesToolCallOutputStr, - } - } - - if output.ResponsesFunctionToolCallOutputBlocks != nil { - var resultBlocks []AnthropicContentBlock - for _, block := range output.ResponsesFunctionToolCallOutputBlocks { - if converted := convertContentBlockToAnthropic(block); converted != nil { - resultBlocks = append(resultBlocks, *converted) - } - } - if len(resultBlocks) > 0 { - return &AnthropicContent{ - ContentBlocks: resultBlocks, - } - } - } - - if output.ResponsesComputerToolCallOutput != nil && output.ResponsesComputerToolCallOutput.ImageURL != nil { - imgBlock := ConvertToAnthropicImageBlock(schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeImage, - ImageURLStruct: &schemas.ChatInputImage{ - URL: *output.ResponsesComputerToolCallOutput.ImageURL, - }, - }) - return &AnthropicContent{ - ContentBlocks: []AnthropicContentBlock{imgBlock}, - } - } - - return nil -} - -// Helper function to convert ResponsesInputItems back to AnthropicMessages -func convertResponsesMessagesToAnthropicMessages(messages []schemas.ResponsesMessage) ([]AnthropicMessage, *AnthropicContent) { - var anthropicMessages []AnthropicMessage - var systemContent *AnthropicContent - var pendingToolCalls []AnthropicContentBlock - var currentAssistantMessage *AnthropicMessage - - for _, msg := range messages { - // Handle nil Type as regular message - msgType := schemas.ResponsesMessageTypeMessage - if msg.Type != nil { - msgType = *msg.Type - } - - switch msgType { - case schemas.ResponsesMessageTypeMessage: - // Flush any pending tool calls first - pendingToolCalls, currentAssistantMessage, anthropicMessages = flushPendingToolCalls( - pendingToolCalls, currentAssistantMessage, anthropicMessages) - - // Handle system messages separately - if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { - if msg.Content != nil { - if msg.Content.ContentStr != nil { - systemContent = &AnthropicContent{ - ContentStr: msg.Content.ContentStr, - } - } else if msg.Content.ContentBlocks != nil { - contentBlocks := convertBifrostContentBlocksToAnthropic(msg.Content.ContentBlocks) - if len(contentBlocks) > 0 { - systemContent = &AnthropicContent{ - ContentBlocks: contentBlocks, - } - } - } - } - continue - } - - // Regular user/assistant message - anthropicMsg := AnthropicMessage{} - - // Set role - if msg.Role != nil { - switch *msg.Role { - case schemas.ResponsesInputMessageRoleUser: - anthropicMsg.Role = AnthropicMessageRoleUser - case schemas.ResponsesInputMessageRoleAssistant: - anthropicMsg.Role = AnthropicMessageRoleAssistant - default: - anthropicMsg.Role = AnthropicMessageRoleUser // Default fallback - } - } else { - anthropicMsg.Role = AnthropicMessageRoleUser // Default fallback - } - - // Convert content - if msg.Content != nil { - if msg.Content.ContentStr != nil { - anthropicMsg.Content = AnthropicContent{ - ContentStr: msg.Content.ContentStr, - } - } else if msg.Content.ContentBlocks != nil { - contentBlocks := convertBifrostContentBlocksToAnthropic(msg.Content.ContentBlocks) - if len(contentBlocks) > 0 { - anthropicMsg.Content = AnthropicContent{ - ContentBlocks: contentBlocks, - } - } - } - } - - anthropicMessages = append(anthropicMessages, anthropicMsg) - - case schemas.ResponsesMessageTypeReasoning: - // Handle reasoning as thinking content - if msg.ResponsesReasoning != nil && len(msg.ResponsesReasoning.Summary) > 0 { - // Find the last assistant message or create one - var targetMsg *AnthropicMessage - if len(anthropicMessages) > 0 && anthropicMessages[len(anthropicMessages)-1].Role == AnthropicMessageRoleAssistant { - targetMsg = &anthropicMessages[len(anthropicMessages)-1] - } else { - // Create new assistant message for reasoning - newMsg := AnthropicMessage{ - Role: AnthropicMessageRoleAssistant, - } - anthropicMessages = append(anthropicMessages, newMsg) - targetMsg = &anthropicMessages[len(anthropicMessages)-1] - } - - // Add thinking blocks - var contentBlocks []AnthropicContentBlock - if targetMsg.Content.ContentBlocks != nil { - contentBlocks = targetMsg.Content.ContentBlocks - } - - for _, reasoningContent := range msg.ResponsesReasoning.Summary { - thinkingBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeThinking, - Thinking: &reasoningContent.Text, - } - contentBlocks = append(contentBlocks, thinkingBlock) - } - - targetMsg.Content = AnthropicContent{ - ContentBlocks: contentBlocks, - } - } - - case schemas.ResponsesMessageTypeFunctionCall: - // Start accumulating tool calls for assistant message - if currentAssistantMessage == nil { - currentAssistantMessage = &AnthropicMessage{ - Role: AnthropicMessageRoleAssistant, - } - } - - if msg.ResponsesToolMessage != nil { - toolUseBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolUse, - } - - if msg.ResponsesToolMessage.CallID != nil { - toolUseBlock.ID = msg.ResponsesToolMessage.CallID - } - if msg.ResponsesToolMessage.Name != nil { - toolUseBlock.Name = msg.ResponsesToolMessage.Name - } - - // Parse arguments as JSON input - if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { - toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) - } - - pendingToolCalls = append(pendingToolCalls, toolUseBlock) - } - - case schemas.ResponsesMessageTypeFunctionCallOutput: - // Flush any pending tool calls first before processing tool results - pendingToolCalls, currentAssistantMessage, anthropicMessages = flushPendingToolCalls( - pendingToolCalls, currentAssistantMessage, anthropicMessages) - - // Handle tool call output - convert to user message with tool_result - if msg.ResponsesToolMessage != nil { - toolResultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolResult, - ToolUseID: msg.ResponsesToolMessage.CallID, - } - - if msg.ResponsesToolMessage.Output != nil { - toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } - - toolResultMsg := AnthropicMessage{ - Role: AnthropicMessageRoleUser, - Content: AnthropicContent{ - ContentBlocks: []AnthropicContentBlock{toolResultBlock}, - }, - } - - anthropicMessages = append(anthropicMessages, toolResultMsg) - } - - case schemas.ResponsesMessageTypeItemReference: - // Handle item reference as regular text message - if msg.Content != nil && msg.Content.ContentStr != nil { - referenceMsg := AnthropicMessage{ - Role: AnthropicMessageRoleUser, // Default to user for references - } - if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleAssistant { - referenceMsg.Role = AnthropicMessageRoleAssistant - } - - referenceMsg.Content = AnthropicContent{ - ContentStr: msg.Content.ContentStr, - } - - anthropicMessages = append(anthropicMessages, referenceMsg) - } - case schemas.ResponsesMessageTypeComputerCall: - // Start accumulating tool calls for assistant message - if currentAssistantMessage == nil { - currentAssistantMessage = &AnthropicMessage{ - Role: AnthropicMessageRoleAssistant, - } - } - - if msg.ResponsesToolMessage != nil { - toolUseBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolUse, - Name: schemas.Ptr(string(AnthropicToolNameComputer)), - } - if msg.ResponsesToolMessage.CallID != nil { - toolUseBlock.ID = msg.ResponsesToolMessage.CallID - } - if msg.ResponsesToolMessage.Name != nil { - toolUseBlock.Name = msg.ResponsesToolMessage.Name - } - - if msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { - toolUseBlock.Input = convertResponsesToAnthropicComputerAction(msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) - } - - pendingToolCalls = append(pendingToolCalls, toolUseBlock) - } - - case schemas.ResponsesMessageTypeMCPCall: - // Check if this is a tool use (from assistant) or tool result (from user) - // Tool use: has Name and Arguments but no Output - // Tool result: has CallID and Output - if msg.ResponsesToolMessage != nil { - // This is a tool use call (assistant calling a tool) - if msg.ResponsesToolMessage.Name != nil { - // Start accumulating MCP tool calls for assistant message - if currentAssistantMessage == nil { - currentAssistantMessage = &AnthropicMessage{ - Role: AnthropicMessageRoleAssistant, - } - } - - toolUseBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolUse, - } - - if msg.ID != nil { - toolUseBlock.ID = msg.ID - } - toolUseBlock.Name = msg.ResponsesToolMessage.Name - - // Set server name if present - if msg.ResponsesToolMessage.ResponsesMCPToolCall != nil && msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel != "" { - toolUseBlock.ServerName = &msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel - } - - // Parse arguments as JSON input - if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { - toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) - } - - pendingToolCalls = append(pendingToolCalls, toolUseBlock) - } else if msg.ResponsesToolMessage.CallID != nil { - // This is a tool result (user providing result of tool execution) - toolResultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolResult, - ID: msg.ResponsesToolMessage.CallID, - } - - if msg.ResponsesToolMessage.Output != nil { - toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } - - toolResultMsg := AnthropicMessage{ - Role: AnthropicMessageRoleUser, - Content: AnthropicContent{ - ContentBlocks: []AnthropicContentBlock{toolResultBlock}, - }, - } - - anthropicMessages = append(anthropicMessages, toolResultMsg) - } - } - - case schemas.ResponsesMessageTypeMCPApprovalRequest: - // MCP approval request is OpenAI-specific for human-in-the-loop workflows - // Convert to Anthropic's mcp_tool_use format (same as regular MCP calls) - if currentAssistantMessage == nil { - currentAssistantMessage = &AnthropicMessage{ - Role: AnthropicMessageRoleAssistant, - } - } - - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { - toolUseBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolUse, - } - - if msg.ID != nil { - toolUseBlock.ID = msg.ID - } - toolUseBlock.Name = msg.ResponsesToolMessage.Name - - // Set server name if present - if msg.ResponsesToolMessage.ResponsesMCPToolCall != nil && msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel != "" { - toolUseBlock.ServerName = &msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel - } - - // Parse arguments as JSON input - if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { - toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) - } - - pendingToolCalls = append(pendingToolCalls, toolUseBlock) - } - - // Handle other tool call types that are not natively supported by Anthropic - case schemas.ResponsesMessageTypeFileSearchCall, - schemas.ResponsesMessageTypeCodeInterpreterCall, - schemas.ResponsesMessageTypeWebSearchCall, - schemas.ResponsesMessageTypeLocalShellCall, - schemas.ResponsesMessageTypeCustomToolCall, - schemas.ResponsesMessageTypeImageGenerationCall: - // Convert unsupported tool calls to regular text messages - if msg.ResponsesToolMessage != nil { - toolCallMsg := AnthropicMessage{ - Role: AnthropicMessageRoleAssistant, - } - - var description string - if msg.ResponsesToolMessage.Name != nil { - description = fmt.Sprintf("Tool call: %s", *msg.ResponsesToolMessage.Name) - if msg.ResponsesToolMessage.Arguments != nil { - description += fmt.Sprintf(" with arguments: %s", *msg.ResponsesToolMessage.Arguments) - } - } else { - description = fmt.Sprintf("Tool call of type: %s", msgType) - } - - toolCallMsg.Content = AnthropicContent{ - ContentStr: &description, - } - - anthropicMessages = append(anthropicMessages, toolCallMsg) - } - - case schemas.ResponsesMessageTypeComputerCallOutput: - // Flush any pending tool calls first before processing tool results - pendingToolCalls, currentAssistantMessage, anthropicMessages = flushPendingToolCalls( - pendingToolCalls, currentAssistantMessage, anthropicMessages) - - // Handle computer call output - convert to user message with tool_result - if msg.ResponsesToolMessage != nil { - toolResultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolResult, - ToolUseID: msg.ResponsesToolMessage.CallID, - } + return nil, nil, anthropicMessages + } + // Return unchanged values if no flush was needed + return pendingContentBlocks, currentAssistantMessage, anthropicMessages +} - if msg.ResponsesToolMessage.Output != nil { - toolResultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } +// convertToolOutputToAnthropicContent converts tool output to Anthropic content format +func convertToolOutputToAnthropicContent(output *schemas.ResponsesToolMessageOutputStruct) *AnthropicContent { + if output == nil { + return nil + } - toolResultMsg := AnthropicMessage{ - Role: AnthropicMessageRoleUser, - Content: AnthropicContent{ - ContentBlocks: []AnthropicContentBlock{toolResultBlock}, - }, - } + if output.ResponsesToolCallOutputStr != nil { + return &AnthropicContent{ + ContentStr: output.ResponsesToolCallOutputStr, + } + } - anthropicMessages = append(anthropicMessages, toolResultMsg) + if output.ResponsesFunctionToolCallOutputBlocks != nil { + var resultBlocks []AnthropicContentBlock + for _, block := range output.ResponsesFunctionToolCallOutputBlocks { + if converted := convertContentBlockToAnthropic(block); converted != nil { + resultBlocks = append(resultBlocks, *converted) } - - case schemas.ResponsesMessageTypeLocalShellCallOutput, - schemas.ResponsesMessageTypeCustomToolCallOutput: - // Handle tool outputs as user messages - if msg.ResponsesToolMessage != nil { - toolOutputMsg := AnthropicMessage{ - Role: AnthropicMessageRoleUser, - } - - var outputText string - // Try to extract output text based on tool type - if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - outputText = *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr - } - - if outputText != "" { - toolOutputMsg.Content = AnthropicContent{ - ContentStr: &outputText, - } - anthropicMessages = append(anthropicMessages, toolOutputMsg) - } + } + if len(resultBlocks) > 0 { + return &AnthropicContent{ + ContentBlocks: resultBlocks, } - - default: - // Skip unknown message types or log them for debugging - continue } } - // Flush any remaining pending tool calls - pendingToolCalls, currentAssistantMessage, anthropicMessages = flushPendingToolCalls( - pendingToolCalls, currentAssistantMessage, anthropicMessages) + if output.ResponsesComputerToolCallOutput != nil && output.ResponsesComputerToolCallOutput.ImageURL != nil { + imgBlock := ConvertToAnthropicImageBlock(schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: *output.ResponsesComputerToolCallOutput.ImageURL, + }, + }) + return &AnthropicContent{ + ContentBlocks: []AnthropicContentBlock{imgBlock}, + } + } - return anthropicMessages, systemContent + return nil } // Helper function to convert Tool back to AnthropicTool @@ -2317,388 +3318,6 @@ func convertResponsesToolChoiceToAnthropic(toolChoice *schemas.ResponsesToolChoi return anthropicChoice } -// Helper function to convert Anthropic content blocks to Responses output messages -func convertAnthropicContentBlocksToResponsesMessages(content []AnthropicContentBlock) []schemas.ResponsesMessage { - var messages []schemas.ResponsesMessage - - for _, block := range content { - switch block.Type { - case AnthropicContentBlockTypeText: - if block.Text != nil { - // Append text to existing message - messages = append(messages, schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: block.Text, - }, - }, - }, - }) - } - - case AnthropicContentBlockTypeImage: - if block.Source != nil { - messages = append(messages, schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - block.toBifrostResponsesImageBlock(), - }, - }, - }) - } - - case AnthropicContentBlockTypeThinking: - if block.Thinking != nil { - // Create reasoning message - messages = append(messages, schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeReasoning, - Text: block.Thinking, - }, - }, - }, - ResponsesReasoning: &schemas.ResponsesReasoning{ - Summary: []schemas.ResponsesReasoningContent{ - { - Text: *block.Thinking, - Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, - }, - }, - EncryptedContent: block.Signature, - }, - }) - } - - case AnthropicContentBlockTypeToolUse: - if block.ID != nil && block.Name != nil { - // Create function call message - message := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: block.ID, - Name: block.Name, - }, - } - - if block.Name != nil && *block.Name == string(AnthropicToolNameComputer) { - message.Type = schemas.Ptr(schemas.ResponsesMessageTypeComputerCall) - message.ResponsesToolMessage.Name = nil - if inputMap, ok := block.Input.(map[string]interface{}); ok { - message.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), - } - } - } else { - message.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(block.Input)) - } - - messages = append(messages, message) - } - case AnthropicContentBlockTypeToolResult: - if block.ToolUseID != nil { - // Create function call output message - msg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: block.ToolUseID, - }, - } - // Initialize nested output struct - msg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} - if block.Content != nil { - if block.Content.ContentStr != nil { - msg.ResponsesToolMessage.Output. - ResponsesToolCallOutputStr = block.Content.ContentStr - } else if block.Content.ContentBlocks != nil { - var outBlocks []schemas.ResponsesMessageContentBlock - for _, cb := range block.Content.ContentBlocks { - switch cb.Type { - case AnthropicContentBlockTypeText: - if cb.Text != nil { - outBlocks = append(outBlocks, schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: cb.Text, - }) - } - case AnthropicContentBlockTypeImage: - if cb.Source != nil { - outBlocks = append(outBlocks, cb.toBifrostResponsesImageBlock()) - } - } - } - msg.ResponsesToolMessage.Output. - ResponsesFunctionToolCallOutputBlocks = outBlocks - } - } - messages = append(messages, msg) - } - - case AnthropicContentBlockTypeMCPToolUse: - if block.ID != nil && block.Name != nil { - // Create MCP call message (tool invocation from assistant) - message := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), - ID: block.ID, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - Name: block.Name, - Arguments: schemas.Ptr(schemas.JsonifyInput(block.Input)), - }, - } - - // Set server name if present - if block.ServerName != nil { - message.ResponsesToolMessage.ResponsesMCPToolCall = &schemas.ResponsesMCPToolCall{ - ServerLabel: *block.ServerName, - } - } - - messages = append(messages, message) - } - - case AnthropicContentBlockTypeMCPToolResult: - if block.ToolUseID != nil { - // Create MCP call message (tool result) - msg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: block.ToolUseID, - }, - } - // Initialize nested output struct - msg.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{} - if block.Content != nil { - if block.Content.ContentStr != nil { - msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = block.Content.ContentStr - } else if block.Content.ContentBlocks != nil { - var outBlocks []schemas.ResponsesMessageContentBlock - for _, cb := range block.Content.ContentBlocks { - if cb.Type == AnthropicContentBlockTypeText { - if cb.Text != nil { - outBlocks = append(outBlocks, schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: cb.Text, - }) - } - } - } - msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = outBlocks - } - } - messages = append(messages, msg) - } - - default: - // Handle other block types if needed - } - } - return messages -} - -// Helper function to convert ChatMessage output to Anthropic content blocks -func convertBifrostMessagesToAnthropicContent(messages []schemas.ResponsesMessage) []AnthropicContentBlock { - var contentBlocks []AnthropicContentBlock - - for _, msg := range messages { - // Handle different message types based on Responses structure - if msg.Type != nil { - switch *msg.Type { - case schemas.ResponsesMessageTypeMessage: - // Regular text message - if msg.Content != nil { - if msg.Content.ContentStr != nil { - contentBlocks = append(contentBlocks, AnthropicContentBlock{ - Type: "text", - Text: msg.Content.ContentStr, - }) - } else if msg.Content.ContentBlocks != nil { - // Convert content blocks - for _, block := range msg.Content.ContentBlocks { - anthropicBlock := convertContentBlockToAnthropic(block) - if anthropicBlock != nil { - contentBlocks = append(contentBlocks, *anthropicBlock) - } - } - } - } - - case schemas.ResponsesMessageTypeFunctionCall: - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { - toolBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolUse, - ID: msg.ResponsesToolMessage.CallID, - } - if msg.ResponsesToolMessage.Name != nil { - toolBlock.Name = msg.ResponsesToolMessage.Name - } - if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { - toolBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) - } - contentBlocks = append(contentBlocks, toolBlock) - } - - case schemas.ResponsesMessageTypeFunctionCallOutput: - // Tool result block - need to extract from ToolMessage - resultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolResult, - } - - if msg.ResponsesToolMessage != nil { - resultBlock.ToolUseID = msg.ResponsesToolMessage.CallID - // Try content from msg.Content first, then Output - if msg.Content != nil && msg.Content.ContentStr != nil { - resultBlock.Content = &AnthropicContent{ - ContentStr: msg.Content.ContentStr, - } - } else if msg.ResponsesToolMessage.Output != nil { - resultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } - } else if msg.Content != nil && msg.Content.ContentStr != nil { - // Fallback to msg.Content when ResponsesToolMessage is nil - resultBlock.Content = &AnthropicContent{ - ContentStr: msg.Content.ContentStr, - } - } - - contentBlocks = append(contentBlocks, resultBlock) - - case schemas.ResponsesMessageTypeReasoning: - // Build thinking from ResponsesReasoning summary, else from reasoning content blocks - var thinking string - var signature *string - if msg.ResponsesReasoning != nil && msg.ResponsesReasoning.Summary != nil { - for _, b := range msg.ResponsesReasoning.Summary { - thinking += b.Text - } - signature = msg.ResponsesReasoning.EncryptedContent - } else if msg.Content != nil && msg.Content.ContentBlocks != nil { - for _, b := range msg.Content.ContentBlocks { - if b.Type == schemas.ResponsesOutputMessageContentTypeReasoning && b.Text != nil { - thinking += *b.Text - } - } - } - if thinking != "" { - contentBlocks = append(contentBlocks, AnthropicContentBlock{ - Type: AnthropicContentBlockTypeThinking, - Thinking: &thinking, - Signature: signature, - }) - } - - case schemas.ResponsesMessageTypeComputerCall: - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { - toolBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeToolUse, - ID: msg.ResponsesToolMessage.CallID, - Name: schemas.Ptr(string(AnthropicToolNameComputer)), - } - - // Convert computer action to Anthropic input format - if msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { - toolBlock.Input = convertResponsesToAnthropicComputerAction(msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) - } - contentBlocks = append(contentBlocks, toolBlock) - } - - case schemas.ResponsesMessageTypeMCPCall: - // Check if this is a tool use (from assistant) or tool result (from user) - // Tool use: has Name and Arguments but no Output - // Tool result: has CallID and Output - if msg.ResponsesToolMessage != nil { - if msg.ResponsesToolMessage.Name != nil { - // This is a tool use call (assistant calling a tool) - toolUseBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolUse, - } - - if msg.ID != nil { - toolUseBlock.ID = msg.ID - } - - if msg.ResponsesToolMessage.Name != nil { - toolUseBlock.Name = msg.ResponsesToolMessage.Name - } - - // Set server name if present - if msg.ResponsesToolMessage.ResponsesMCPToolCall != nil && msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel != "" { - toolUseBlock.ServerName = &msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel - } - - // Parse arguments as JSON input - if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { - toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) - } - - contentBlocks = append(contentBlocks, toolUseBlock) - } else if msg.ResponsesToolMessage.CallID != nil { - // This is a tool result (user providing result of tool execution) - resultBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolResult, - ToolUseID: msg.ResponsesToolMessage.CallID, - } - - if msg.ResponsesToolMessage.Output != nil { - resultBlock.Content = convertToolOutputToAnthropicContent(msg.ResponsesToolMessage.Output) - } - - contentBlocks = append(contentBlocks, resultBlock) - } - } - - case schemas.ResponsesMessageTypeMCPApprovalRequest: - // MCP approval request is OpenAI-specific for human-in-the-loop workflows - // Convert to Anthropic's mcp_tool_use format (same as regular MCP calls) - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { - toolUseBlock := AnthropicContentBlock{ - Type: AnthropicContentBlockTypeMCPToolUse, - } - - if msg.ID != nil { - toolUseBlock.ID = msg.ID - } - toolUseBlock.Name = msg.ResponsesToolMessage.Name - - // Set server name if present - if msg.ResponsesToolMessage.ResponsesMCPToolCall != nil && msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel != "" { - toolUseBlock.ServerName = &msg.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel - } - - // Parse arguments as JSON input - if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { - toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) - } - - contentBlocks = append(contentBlocks, toolUseBlock) - } - - default: - // Handle other types as text if they have content - if msg.Content != nil && msg.Content.ContentStr != nil { - contentBlocks = append(contentBlocks, AnthropicContentBlock{ - Type: AnthropicContentBlockTypeText, - Text: msg.Content.ContentStr, - }) - } - } - } - } - - return contentBlocks -} - // Helper function to convert ContentBlock to AnthropicContentBlock func convertContentBlockToAnthropic(block schemas.ResponsesMessageContentBlock) *AnthropicContentBlock { switch block.Type { @@ -3056,18 +3675,38 @@ func convertAnthropicToResponsesComputerAction(inputMap map[string]interface{}) return action } -// isBase64Like checks if a string looks like base64 encoded data -// Signatures are typically long base64 strings (>200 chars) -func isBase64Like(s string) bool { - if len(s) < 100 { - return false - } - // Check if string contains only base64 characters - base64Chars := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=" - for _, char := range s { - if !strings.ContainsRune(base64Chars, char) { - return false - } +// generateSyntheticInputJSONDeltas creates synthetic input_json_delta events from complete JSON arguments +// This simulates the streaming behavior that Anthropic provides natively +func generateSyntheticInputJSONDeltas(argumentsJSON string, contentIndex *int) []*AnthropicStreamEvent { + var events []*AnthropicStreamEvent + + // Chunk size for synthetic streaming (similar to how Anthropic chunks arguments) + chunkSize := 8 // Small chunks to simulate realistic streaming + + // Start with empty delta to match Anthropic's behavior + events = append(events, &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockDelta, + Index: contentIndex, + Delta: &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: schemas.Ptr(""), + }, + }) + + // Break the JSON into chunks + for i := 0; i < len(argumentsJSON); i += chunkSize { + end := min(i+chunkSize, len(argumentsJSON)) + + chunk := argumentsJSON[i:end] + events = append(events, &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockDelta, + Index: contentIndex, + Delta: &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeInputJSON, + PartialJSON: &chunk, + }, + }) } - return true + + return events } diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 738d219e0..1fffd91c6 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" ) @@ -54,6 +55,9 @@ type AnthropicMessageRequest struct { Thinking *AnthropicThinking `json:"thinking,omitempty"` OutputFormat interface{} `json:"output_format,omitempty"` // This feature requires the beta header: "anthropic-beta": "structured-outputs-2025-11-13" and currently only supported for Claude Sonnet 4.5 and Claude Opus 4.1 + // Extra params for advanced use cases + ExtraParams map[string]interface{} `json:"extra_params,omitempty"` + // Bifrost specific field (only parsed when converting from Provider -> Bifrost request) Fallbacks []string `json:"fallbacks,omitempty"` } @@ -72,6 +76,69 @@ func (mr *AnthropicMessageRequest) IsStreamingRequested() bool { return mr.Stream != nil && *mr.Stream } +// Known fields for AnthropicMessageRequest +var anthropicMessageRequestKnownFields = map[string]bool{ + "model": true, + "max_tokens": true, + "messages": true, + "metadata": true, + "system": true, + "temperature": true, + "top_p": true, + "top_k": true, + "stop_sequences": true, + "stream": true, + "tools": true, + "tool_choice": true, + "mcp_servers": true, + "thinking": true, + "output_format": true, + "extra_params": true, + "fallbacks": true, +} + +// UnmarshalJSON implements custom JSON unmarshalling for AnthropicMessageRequest. +// This captures all unregistered fields into ExtraParams. +func (mr *AnthropicMessageRequest) UnmarshalJSON(data []byte) error { + // Create an alias type to avoid infinite recursion + type Alias AnthropicMessageRequest + + // First, unmarshal into the alias to populate all known fields + aux := &struct { + *Alias + }{ + Alias: (*Alias)(mr), + } + + if err := sonic.Unmarshal(data, aux); err != nil { + return err + } + + // Parse JSON to extract unknown fields + var rawData map[string]json.RawMessage + if err := sonic.Unmarshal(data, &rawData); err != nil { + return err + } + + // Initialize ExtraParams if not already initialized + if mr.ExtraParams == nil { + mr.ExtraParams = make(map[string]interface{}) + } + + // Extract unknown fields + for key, value := range rawData { + if !anthropicMessageRequestKnownFields[key] { + var v interface{} + if err := sonic.Unmarshal(value, &v); err != nil { + continue // Skip fields that can't be unmarshaled + } + mr.ExtraParams[key] = v + } + } + + return nil +} + type AnthropicMessageRole string const ( @@ -100,13 +167,13 @@ func (mc AnthropicContent) MarshalJSON() ([]byte, error) { } if mc.ContentStr != nil { - return json.Marshal(*mc.ContentStr) + return sonic.Marshal(*mc.ContentStr) } if mc.ContentBlocks != nil { - return json.Marshal(mc.ContentBlocks) + return sonic.Marshal(mc.ContentBlocks) } // If both are nil, return null - return json.Marshal(nil) + return sonic.Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for AnthropicContent. @@ -114,14 +181,14 @@ func (mc AnthropicContent) MarshalJSON() ([]byte, error) { func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := json.Unmarshal(data, &stringContent); err == nil { + if err := sonic.Unmarshal(data, &stringContent); err == nil { mc.ContentStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []AnthropicContentBlock - if err := json.Unmarshal(data, &arrayContent); err == nil { + if err := sonic.Unmarshal(data, &arrayContent); err == nil { mc.ContentBlocks = arrayContent return nil } @@ -132,15 +199,16 @@ func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { type AnthropicContentBlockType string const ( - AnthropicContentBlockTypeText AnthropicContentBlockType = "text" - AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" - AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" - AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" - AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" - AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" - AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" - AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" - AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" + AnthropicContentBlockTypeText AnthropicContentBlockType = "text" + AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" + AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" + AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" + AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" + AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" + AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" + AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" + AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" + AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" ) // AnthropicContentBlock represents content in Anthropic message format @@ -149,6 +217,7 @@ type AnthropicContentBlock struct { Text *string `json:"text,omitempty"` // For text content Thinking *string `json:"thinking,omitempty"` // For thinking content Signature *string `json:"signature,omitempty"` // For signature content + Data *string `json:"data,omitempty"` // For data content (encrypted data for redacted thinking, signature does not come with this) ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content ID *string `json:"id,omitempty"` // For tool_use content Name *string `json:"name,omitempty"` // For tool_use content @@ -295,11 +364,11 @@ type AnthropicTextResponse struct { // AnthropicUsage represents usage information in Anthropic format type AnthropicUsage struct { - InputTokens int `json:"input_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` - CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` - CacheCreation *AnthropicUsageCacheCreation `json:"cache_creation,omitempty"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + CacheCreation AnthropicUsageCacheCreation `json:"cache_creation"` + OutputTokens int `json:"output_tokens"` } type AnthropicUsageCacheCreation struct { @@ -345,13 +414,13 @@ const ( // AnthropicStreamDelta represents incremental updates to content blocks during streaming (legacy) type AnthropicStreamDelta struct { - Type AnthropicStreamDeltaType `json:"type"` + Type AnthropicStreamDeltaType `json:"type,omitempty"` Text *string `json:"text,omitempty"` PartialJSON *string `json:"partial_json,omitempty"` Thinking *string `json:"thinking,omitempty"` Signature *string `json:"signature,omitempty"` StopReason *AnthropicStopReason `json:"stop_reason,omitempty"` // only not present in "message_start" events - StopSequence *string `json:"stop_sequence,omitempty"` + StopSequence *string `json:"stop_sequence"` } // ==================== MODEL TYPES ==================== diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index 3fa27c859..f8052a4a1 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -1,8 +1,12 @@ package anthropic import ( + "context" "encoding/json" + "fmt" + "github.com/bytedance/sonic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -23,6 +27,54 @@ var ( } ) +func getRequestBodyForResponses(ctx context.Context, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { + var jsonBody []byte + var err error + + // Check if raw request body should be used + if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody { + jsonBody = request.GetRawRequestBody() + // Unmarshal and check if model and region are present + var requestBody map[string]interface{} + if err := sonic.Unmarshal(jsonBody, &requestBody); err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, fmt.Errorf("failed to unmarshal request body: %w", err), providerName) + } + // Add max_tokens if not present + if _, exists := requestBody["max_tokens"]; !exists { + requestBody["max_tokens"] = AnthropicDefaultMaxTokens + } + // Add stream if not present + if isStreaming { + requestBody["stream"] = true + } + jsonBody, err = sonic.Marshal(requestBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + } + } else { + // Convert request to Anthropic format + reqBody, err := ToAnthropicResponsesRequest(request) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerName) + } + if reqBody == nil { + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + } + + if isStreaming { + reqBody.Stream = schemas.Ptr(true) + } + + // Convert struct to map + jsonBody, err = sonic.Marshal(reqBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + } + } + + return jsonBody, nil +} + // ConvertAnthropicFinishReasonToBifrost converts provider finish reasons to Bifrost format func ConvertAnthropicFinishReasonToBifrost(providerReason AnthropicStopReason) string { if bifrostReason, ok := anthropicFinishReasonToBifrost[providerReason]; ok { diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 870ee0d4f..74bb7d72f 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -545,28 +545,23 @@ func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, r return nil, err } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( - ctx, - request, - func() (any, error) { - if schemas.IsAnthropicModel(deployment) { - reqBody, err := anthropic.ToAnthropicResponsesRequest(request) - if err != nil { - return nil, err - } - if reqBody != nil { - reqBody.Model = deployment - } - return reqBody, nil - } else { + var jsonData []byte + var bifrostErr *schemas.BifrostError + if schemas.IsAnthropicModel(deployment) { + jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), false) + } else { + jsonData, bifrostErr = providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { reqBody := openai.ToOpenAIResponsesRequest(request) if reqBody != nil { reqBody.Model = deployment } return reqBody, nil - } - }, - provider.GetProviderKey()) + }, + provider.GetProviderKey()) + } if bifrostErr != nil { return nil, bifrostErr } @@ -652,23 +647,9 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint) - jsonData, err := providerUtils.CheckContextAndGetRequestBody( - ctx, - request, - func() (any, error) { - reqBody, err := anthropic.ToAnthropicResponsesRequest(request) - if err != nil { - return nil, err - } - if reqBody != nil { - reqBody.Model = deployment - reqBody.Stream = schemas.Ptr(true) - } - return reqBody, nil - }, - provider.GetProviderKey()) - if err != nil { - return nil, err + jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), true) + if bifrostErr != nil { + return nil, bifrostErr } // Use shared streaming logic from Anthropic diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go new file mode 100644 index 000000000..b5784cc5f --- /dev/null +++ b/core/providers/azure/utils.go @@ -0,0 +1,63 @@ +package azure + +import ( + "context" + "fmt" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/providers/anthropic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" +) + +func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { + var jsonBody []byte + var err error + + // Check if raw request body should be used + if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody { + jsonBody = request.GetRawRequestBody() + // Unmarshal and check if model and region are present + var requestBody map[string]interface{} + if err := sonic.Unmarshal(jsonBody, &requestBody); err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, fmt.Errorf("failed to unmarshal request body: %w", err), providerName) + } + // Add max_tokens if not present + if _, exists := requestBody["max_tokens"]; !exists { + requestBody["max_tokens"] = anthropic.AnthropicDefaultMaxTokens + } + // Replace model with deployment + requestBody["model"] = deployment + // Add stream if not present + if isStreaming { + requestBody["stream"] = true + } + jsonBody, err = sonic.Marshal(requestBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + } + } else { + // Convert request to Anthropic format + reqBody, err := anthropic.ToAnthropicResponsesRequest(request) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerName) + } + if reqBody == nil { + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + } + + // Set deployment as model + reqBody.Model = deployment + if isStreaming { + reqBody.Stream = schemas.Ptr(true) + } + + // Convert struct to map + jsonBody, err = sonic.Marshal(reqBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + } + } + + return jsonBody, nil +} diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index c8d9cf2c2..9dbcb8a12 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -1035,7 +1035,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) if err == io.EOF { // End of stream - finalize any open items finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage) @@ -1057,6 +1056,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu if i == len(finalResponses)-1 { // Set raw request if enabled + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonData) } @@ -1067,6 +1067,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu } break } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) return diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index 4c1861c70..455d47039 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -484,6 +484,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -523,6 +524,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -534,6 +536,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { }, }, { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -573,6 +576,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -617,6 +621,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -680,6 +685,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -749,6 +755,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ @@ -778,7 +785,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { "requestMetadata": map[string]string{ "user": "test-user", }, - "additionalModelRequestFieldPaths": map[string]interface{}{ + "additionalModelRequestFieldPaths": schemas.OrderedMap{ "customField": "customValue", }, "additionalModelResponseFieldPaths": []string{"field1", "field2"}, @@ -813,7 +820,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Input: []schemas.ResponsesMessage{ { Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: schemas.Ptr("in_progress"), + Status: schemas.Ptr("completed"), ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: schemas.Ptr("tool-use-123"), Name: schemas.Ptr("get_weather"), @@ -851,17 +858,11 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Model: "claude-3-sonnet", Input: []schemas.ResponsesMessage{ { - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), - Status: schemas.Ptr("completed"), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: schemas.Ptr("tool-use-123"), Output: &schemas.ResponsesToolMessageOutputStruct{ - ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: schemas.Ptr("The weather in NYC is sunny, 72°F"), - }, - }, + ResponsesToolCallOutputStr: schemas.Ptr("The weather in NYC is sunny, 72°F"), }, }, }, @@ -911,7 +912,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Input: []schemas.ResponsesMessage{ { Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: schemas.Ptr("in_progress"), + Status: schemas.Ptr("completed"), ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: schemas.Ptr("tool-use-456"), Name: schemas.Ptr("calculate"), @@ -919,17 +920,11 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { }, }, { - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), - Status: schemas.Ptr("completed"), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: schemas.Ptr("tool-use-456"), Output: &schemas.ResponsesToolMessageOutputStruct{ - ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: schemas.Ptr("4"), - }, - }, + ResponsesToolCallOutputStr: schemas.Ptr("4"), }, }, }, @@ -1500,7 +1495,22 @@ func TestBifrostToBedrockResponseConversion(t *testing.T) { } } else { require.NoError(t, err) - assert.Equal(t, tt.expected, actual) + // Compare structure instead of exact equality since IDs may be generated + if tt.expected != nil && actual != nil { + assert.Equal(t, tt.expected.StopReason, actual.StopReason) + assert.Equal(t, tt.expected.Output.Message.Role, actual.Output.Message.Role) + assert.Equal(t, len(tt.expected.Output.Message.Content), len(actual.Output.Message.Content)) + if tt.expected.Usage != nil { + assert.Equal(t, tt.expected.Usage.InputTokens, actual.Usage.InputTokens) + assert.Equal(t, tt.expected.Usage.OutputTokens, actual.Usage.OutputTokens) + assert.Equal(t, tt.expected.Usage.TotalTokens, actual.Usage.TotalTokens) + } + if tt.expected.Metrics != nil { + assert.Equal(t, tt.expected.Metrics.LatencyMs, actual.Metrics.LatencyMs) + } + } else { + assert.Equal(t, tt.expected, actual) + } } }) } @@ -1671,12 +1681,41 @@ func TestBedrockToBifrostResponseConversion(t *testing.T) { } } else { require.NoError(t, err) - // Note: CreatedAt is set to current time, so we can't compare it exactly + // Note: CreatedAt and IDs are set at runtime, so compare structure instead if actual != nil { assert.Greater(t, actual.CreatedAt, 0) actual.CreatedAt = tt.expected.CreatedAt + + // For output messages, IDs are generated, so we need to compare by value not identity + if len(actual.Output) > 0 && len(tt.expected.Output) > 0 { + assert.Equal(t, len(tt.expected.Output), len(actual.Output)) + for i := range actual.Output { + assert.Equal(t, tt.expected.Output[i].Type, actual.Output[i].Type) + assert.Equal(t, tt.expected.Output[i].Role, actual.Output[i].Role) + assert.Equal(t, tt.expected.Output[i].Status, actual.Output[i].Status) + if tt.expected.Output[i].ResponsesToolMessage != nil { + assert.NotNil(t, actual.Output[i].ResponsesToolMessage) + require.NotNil(t, actual.Output[i].ResponsesToolMessage.Name) + require.NotNil(t, actual.Output[i].ResponsesToolMessage.CallID) + require.NotNil(t, actual.Output[i].ResponsesToolMessage.Arguments) + assert.Equal(t, *tt.expected.Output[i].ResponsesToolMessage.Name, *actual.Output[i].ResponsesToolMessage.Name) + assert.Equal(t, *tt.expected.Output[i].ResponsesToolMessage.CallID, *actual.Output[i].ResponsesToolMessage.CallID) + assert.Equal(t, *tt.expected.Output[i].ResponsesToolMessage.Arguments, *actual.Output[i].ResponsesToolMessage.Arguments) + } + if tt.expected.Output[i].Content != nil { + assert.Equal(t, tt.expected.Output[i].Content, actual.Output[i].Content) + } + } + } + + // Compare usage if present + if tt.expected.Usage != nil { + assert.NotNil(t, actual.Usage) + assert.Equal(t, tt.expected.Usage.InputTokens, actual.Usage.InputTokens) + assert.Equal(t, tt.expected.Usage.OutputTokens, actual.Usage.OutputTokens) + assert.Equal(t, tt.expected.Usage.TotalTokens, actual.Usage.TotalTokens) + } } - assert.Equal(t, tt.expected, actual) } }) } diff --git a/core/providers/bedrock/responses.go b/core/providers/bedrock/responses.go index d028bff48..955a585ad 100644 --- a/core/providers/bedrock/responses.go +++ b/core/providers/bedrock/responses.go @@ -6,7 +6,10 @@ import ( "sync" "time" + "github.com/bytedance/sonic" + "github.com/google/uuid" "github.com/maximhq/bifrost/core/providers/anthropic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -17,13 +20,15 @@ type BedrockResponsesStreamState struct { ItemIDs map[int]string // Maps output_index to item ID for stable IDs ToolCallIDs map[int]string // Maps output_index to tool call ID (callID) ToolCallNames map[int]string // Maps output_index to tool call name + ReasoningContentIndices map[int]bool // Tracks which content indices are reasoning blocks + CompletedOutputIndices map[int]bool // Tracks which output indices have been completed CurrentOutputIndex int // Current output index counter MessageID *string // Message ID (generated) Model *string // Model name + StopReason *string // Stop reason for the message CreatedAt int // Timestamp for created_at consistency HasEmittedCreated bool // Whether we've emitted response.created HasEmittedInProgress bool // Whether we've emitted response.in_progress - TextItemClosed bool // Whether text item has been closed } // bedrockResponsesStreamStatePool provides a pool for Bedrock responses stream state objects. @@ -35,11 +40,12 @@ var bedrockResponsesStreamStatePool = sync.Pool{ ItemIDs: make(map[int]string), ToolCallIDs: make(map[int]string), ToolCallNames: make(map[int]string), + ReasoningContentIndices: make(map[int]bool), + CompletedOutputIndices: make(map[int]bool), CurrentOutputIndex: 0, CreatedAt: int(time.Now().Unix()), HasEmittedCreated: false, HasEmittedInProgress: false, - TextItemClosed: false, } }, } @@ -74,14 +80,24 @@ func acquireBedrockResponsesStreamState() *BedrockResponsesStreamState { } else { clear(state.ToolCallNames) } + if state.ReasoningContentIndices == nil { + state.ReasoningContentIndices = make(map[int]bool) + } else { + clear(state.ReasoningContentIndices) + } + if state.CompletedOutputIndices == nil { + state.CompletedOutputIndices = make(map[int]bool) + } else { + clear(state.CompletedOutputIndices) + } // Reset other fields state.CurrentOutputIndex = 0 state.MessageID = nil state.Model = nil + state.StopReason = nil state.CreatedAt = int(time.Now().Unix()) state.HasEmittedCreated = false state.HasEmittedInProgress = false - state.TextItemClosed = false return state } @@ -120,912 +136,1487 @@ func (state *BedrockResponsesStreamState) flush() { } else { clear(state.ToolCallNames) } + if state.ReasoningContentIndices == nil { + state.ReasoningContentIndices = make(map[int]bool) + } else { + clear(state.ReasoningContentIndices) + } + if state.CompletedOutputIndices == nil { + state.CompletedOutputIndices = make(map[int]bool) + } else { + clear(state.CompletedOutputIndices) + } state.CurrentOutputIndex = 0 state.MessageID = nil state.Model = nil + state.StopReason = nil state.CreatedAt = int(time.Now().Unix()) state.HasEmittedCreated = false state.HasEmittedInProgress = false - state.TextItemClosed = false } -// ToBifrostResponsesRequest converts a BedrockConverseRequest to Bifrost Responses Request format -func (request *BedrockConverseRequest) ToBifrostResponsesRequest() (*schemas.BifrostResponsesRequest, error) { - if request == nil { - return nil, fmt.Errorf("bedrock request is nil") - } +// ToBifrostResponsesStream converts a Bedrock stream event to a Bifrost Responses Stream response +// Returns a slice of responses to support cases where a single event produces multiple responses +func (chunk *BedrockStreamEvent) ToBifrostResponsesStream(sequenceNumber int, state *BedrockResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError, bool) { + switch { + case chunk.Role != nil: + // Message start - emit response.created and response.in_progress (OpenAI-style lifecycle) + var responses []*schemas.BifrostResponsesStreamResponse - // Extract provider from model ID (format: "bedrock/model-name") - provider, model := schemas.ParseModelString(request.ModelID, schemas.Bedrock) + // Generate message ID if not already set + if state.MessageID == nil { + messageID := fmt.Sprintf("msg_%d", state.CreatedAt) + state.MessageID = &messageID + } - bifrostReq := &schemas.BifrostResponsesRequest{ - Provider: provider, - Model: model, - Params: &schemas.ResponsesParameters{}, - } - - // Convert system messages first (they should appear at the top) - if len(request.System) > 0 { - for _, sysMsg := range request.System { - if sysMsg.Text != nil { - systemRole := schemas.ResponsesInputMessageRoleSystem - bifrostMsg := schemas.ResponsesMessage{ - Role: &systemRole, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: sysMsg.Text, - }, - }, - }, - } - bifrostReq.Input = append(bifrostReq.Input, bifrostMsg) + // Emit response.created + if !state.HasEmittedCreated { + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, + } + if state.Model != nil { + response.Model = *state.Model } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCreated, + SequenceNumber: sequenceNumber, + Response: response, + }) + state.HasEmittedCreated = true } - } - // Convert regular messages - for _, msg := range request.Messages { - var role schemas.ResponsesMessageRoleType - switch msg.Role { - case BedrockMessageRoleUser: - role = schemas.ResponsesInputMessageRoleUser - case BedrockMessageRoleAssistant: - role = schemas.ResponsesInputMessageRoleAssistant - default: - continue + // Emit response.in_progress + if !state.HasEmittedInProgress { + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, // Use same timestamp + } + if state.Model != nil { + response.Model = *state.Model + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeInProgress, + SequenceNumber: sequenceNumber + len(responses), + Response: response, + }) + state.HasEmittedInProgress = true } - // Convert content blocks - each content block may become its own message - // (e.g., tool calls and tool results are separate items in Responses API) - for _, content := range msg.Content { - if content.ToolUse != nil { - // Tool calls are separate function_call type items (no role field) - callType := schemas.ResponsesMessageTypeFunctionCall - callStatus := "in_progress" - toolCallMsg := schemas.ResponsesMessage{ - Type: &callType, - Status: &callStatus, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: &content.ToolUse.ToolUseID, - Name: &content.ToolUse.Name, - Arguments: schemas.Ptr(schemas.JsonifyInput(content.ToolUse.Input)), - }, + // Don't pre-create any items here - let each content block create its own item when it first appears + + if len(responses) > 0 { + return responses, nil, false + } + + case chunk.Start != nil: + // Handle content block start (text content or tool use) + contentBlockIndex := 0 + if chunk.ContentBlockIndex != nil { + contentBlockIndex = *chunk.ContentBlockIndex + } + + // Check if this is a tool use start + if chunk.Start.ToolUse != nil { + var responses []*schemas.BifrostResponsesStreamResponse + + // Close any open reasoning blocks first (Anthropic sends content_block_stop before starting new blocks) + for prevContentIndex := range state.ReasoningContentIndices { + prevOutputIndex, prevExists := state.ContentIndexToOutputIndex[prevContentIndex] + if !prevExists { + continue } - bifrostReq.Input = append(bifrostReq.Input, toolCallMsg) - } else if content.ToolResult != nil { - // Tool results are separate function_call_output type items (no role field) - resultType := schemas.ResponsesMessageTypeFunctionCallOutput - resultStatus := "completed" - var toolResultContent []schemas.ResponsesMessageContentBlock - for _, resultContent := range content.ToolResult.Content { - if resultContent.Text != nil { - toolResultContent = append(toolResultContent, schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: resultContent.Text, - }) - } + + // Skip already completed output indices + if state.CompletedOutputIndices[prevOutputIndex] { + continue } - toolResultMsg := schemas.ResponsesMessage{ - Type: &resultType, - Status: &resultStatus, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: &content.ToolResult.ToolUseID, - Output: &schemas.ResponsesToolMessageOutputStruct{ - ResponsesFunctionToolCallOutputBlocks: toolResultContent, - }, - }, + + itemID := state.ItemIDs[prevOutputIndex] + + // For reasoning items, content_index is always 0 + reasoningContentIndex := 0 + + // Emit reasoning_summary_text.done + emptyText := "" + reasoningDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: &reasoningContentIndex, + Text: &emptyText, } - bifrostReq.Input = append(bifrostReq.Input, toolResultMsg) - } else if content.Text != nil { - // Regular text content - use role-based message - // For assistant messages (previous model outputs), use output_text type - // For user/system messages, use input_text type - textBlockType := schemas.ResponsesInputMessageContentBlockTypeText - if msg.Role == BedrockMessageRoleAssistant { - textBlockType = schemas.ResponsesOutputMessageContentTypeText + if itemID != "" { + reasoningDoneResponse.ItemID = &itemID } - bifrostMsg := schemas.ResponsesMessage{ - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: textBlockType, - Text: content.Text, - }, - }, - }, + responses = append(responses, reasoningDoneResponse) + + // Emit content_part.done for reasoning + partDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: &reasoningContentIndex, } - bifrostReq.Input = append(bifrostReq.Input, bifrostMsg) - } else if content.Image != nil { - // Image content - use role-based message - imageBlock := schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeImage, - } - if content.Image.Source.Bytes != nil { - // Construct proper data URI from format and base64 bytes - // Format should be "png", "jpeg", "gif", "webp" etc. - format := content.Image.Format - if format == "" { - format = "jpeg" // default to jpeg if format not specified - } - dataURI := fmt.Sprintf("data:image/%s;base64,%s", format, *content.Image.Source.Bytes) - imageBlock.ResponsesInputMessageContentBlockImage = &schemas.ResponsesInputMessageContentBlockImage{ - ImageURL: &dataURI, - } + if itemID != "" { + partDoneResponse.ItemID = &itemID } - bifrostMsg := schemas.ResponsesMessage{ - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{imageBlock}, - }, + responses = append(responses, partDoneResponse) + + // Emit output_item.done for reasoning + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, } - bifrostReq.Input = append(bifrostReq.Input, bifrostMsg) - } - } - } + if itemID != "" { + doneItem.ID = &itemID + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: &reasoningContentIndex, + Item: doneItem, + }) - // Convert inference config to parameters - if request.InferenceConfig != nil { - if request.InferenceConfig.MaxTokens != nil { - bifrostReq.Params.MaxOutputTokens = request.InferenceConfig.MaxTokens - } - if request.InferenceConfig.Temperature != nil { - bifrostReq.Params.Temperature = request.InferenceConfig.Temperature - } - if request.InferenceConfig.TopP != nil { - bifrostReq.Params.TopP = request.InferenceConfig.TopP - } - if len(request.InferenceConfig.StopSequences) > 0 { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) + // Mark this output index as completed + state.CompletedOutputIndices[prevOutputIndex] = true } - bifrostReq.Params.ExtraParams["stop"] = request.InferenceConfig.StopSequences - } - } + // Clear reasoning content indices after closing them + clear(state.ReasoningContentIndices) - // Convert tool config - if request.ToolConfig != nil && len(request.ToolConfig.Tools) > 0 { - for _, tool := range request.ToolConfig.Tools { - if tool.ToolSpec != nil { - bifrostTool := schemas.ResponsesTool{ - Type: schemas.ResponsesToolTypeFunction, - Name: &tool.ToolSpec.Name, - Description: tool.ToolSpec.Description, - ResponsesToolFunction: &schemas.ResponsesToolFunction{}, + // Close any open tool call blocks before starting a new one (Anthropic completes each block before starting next) + for prevContentIndex, prevOutputIndex := range state.ContentIndexToOutputIndex { + // Skip reasoning blocks (already handled above) + if state.ReasoningContentIndices[prevContentIndex] { + continue } - // Handle different types for InputSchema.JSON - if params, ok := tool.ToolSpec.InputSchema.JSON.(*schemas.ToolFunctionParameters); ok { - bifrostTool.ResponsesToolFunction.Parameters = params - } else if paramsMap, ok := tool.ToolSpec.InputSchema.JSON.(map[string]interface{}); ok { - // Convert map to ToolFunctionParameters - params := &schemas.ToolFunctionParameters{} - if typeVal, ok := paramsMap["type"].(string); ok { - params.Type = typeVal - } - // Handle both pointer and non-pointer properties - if props, ok := schemas.SafeExtractOrderedMap(paramsMap["properties"]); ok { - params.Properties = &props - } - if required, ok := paramsMap["required"].([]interface{}); ok { - reqStrings := make([]string, 0, len(required)) - for _, r := range required { - if rStr, ok := r.(string); ok { - reqStrings = append(reqStrings, rStr) - } - } - params.Required = reqStrings - } else if required, ok := paramsMap["required"].([]string); ok { - params.Required = required - } - bifrostTool.ResponsesToolFunction.Parameters = params + // Skip already completed output indices + if state.CompletedOutputIndices[prevOutputIndex] { + continue } - bifrostReq.Params.Tools = append(bifrostReq.Params.Tools, bifrostTool) - } - } - } + // Check if this is a tool call + prevToolCallID := state.ToolCallIDs[prevOutputIndex] + if prevToolCallID == "" { + continue // Not a tool call + } - // Convert guardrail config to extra params - if request.GuardrailConfig != nil { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } + prevItemID := state.ItemIDs[prevOutputIndex] + prevToolName := state.ToolCallNames[prevOutputIndex] + accumulatedArgs := state.ToolArgumentBuffers[prevOutputIndex] - guardrailMap := map[string]interface{}{ - "guardrailIdentifier": request.GuardrailConfig.GuardrailIdentifier, - "guardrailVersion": request.GuardrailConfig.GuardrailVersion, - } - if request.GuardrailConfig.Trace != nil { - guardrailMap["trace"] = *request.GuardrailConfig.Trace - } - bifrostReq.Params.ExtraParams["guardrailConfig"] = guardrailMap - } + // Emit content_part.done for tool call + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: schemas.Ptr(prevContentIndex), + ItemID: &prevItemID, + }) - // Convert additional model request fields to extra params - if len(request.AdditionalModelRequestFields) > 0 { - reasoningConfig, ok := schemas.SafeExtractFromMap(request.AdditionalModelRequestFields, "reasoning_config") - if ok { - if reasoningConfigMap, ok := reasoningConfig.(map[string]interface{}); ok { - if typeStr, ok := schemas.SafeExtractString(reasoningConfigMap["type"]); ok { - if typeStr == "enabled" { - if maxTokens, ok := schemas.SafeExtractInt(reasoningConfigMap["budget_tokens"]); ok { - bifrostReq.Params.Reasoning = &schemas.ResponsesParametersReasoning{ - Effort: schemas.Ptr("auto"), - MaxTokens: schemas.Ptr(maxTokens), - } + // Emit function_call_arguments.done with full arguments + if accumulatedArgs != "" { + var doneItem *schemas.ResponsesMessage + if prevToolCallID != "" || prevToolName != "" { + doneItem = &schemas.ResponsesMessage{ + ResponsesToolMessage: &schemas.ResponsesToolMessage{}, } - } else { - bifrostReq.Params.Reasoning = &schemas.ResponsesParametersReasoning{ - Effort: schemas.Ptr("none"), + if prevToolCallID != "" { + doneItem.ResponsesToolMessage.CallID = &prevToolCallID } + if prevToolName != "" { + doneItem.ResponsesToolMessage.Name = &prevToolName + } + } + + argsDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + Arguments: &accumulatedArgs, + } + if prevItemID != "" { + argsDoneResponse.ItemID = &prevItemID } + if doneItem != nil { + argsDoneResponse.Item = doneItem + } + responses = append(responses, argsDoneResponse) } - } - } - } - // Convert performance config to extra params - if request.PerformanceConfig != nil { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } + // Emit output_item.done for tool call + statusCompleted := "completed" + toolDoneItem := &schemas.ResponsesMessage{ + ID: &prevItemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: &statusCompleted, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &prevToolCallID, + Name: &prevToolName, + Arguments: &accumulatedArgs, + }, + } - perfConfigMap := map[string]interface{}{} - if request.PerformanceConfig.Latency != nil { - perfConfigMap["latency"] = *request.PerformanceConfig.Latency - } - if len(perfConfigMap) > 0 { - bifrostReq.Params.ExtraParams["performanceConfig"] = perfConfigMap - } - } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: schemas.Ptr(prevContentIndex), + ItemID: &prevItemID, + Item: toolDoneItem, + }) - // Convert prompt variables to extra params - if len(request.PromptVariables) > 0 { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } + // Mark this output index as completed + state.CompletedOutputIndices[prevOutputIndex] = true + } - promptVarsMap := make(map[string]interface{}) - for key, value := range request.PromptVariables { - varMap := map[string]interface{}{} - if value.Text != nil { - varMap["text"] = *value.Text - } - if len(varMap) > 0 { - promptVarsMap[key] = varMap - } - } - if len(promptVarsMap) > 0 { - bifrostReq.Params.ExtraParams["promptVariables"] = promptVarsMap - } - } + // Create new output index for this tool use + outputIndex := state.CurrentOutputIndex + state.ContentIndexToOutputIndex[contentBlockIndex] = outputIndex + state.CurrentOutputIndex++ // Increment for next use - // Convert request metadata to extra params - if len(request.RequestMetadata) > 0 { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } - bifrostReq.Params.ExtraParams["requestMetadata"] = request.RequestMetadata - } + // Store tool use ID as item ID and call ID + toolUseID := chunk.Start.ToolUse.ToolUseID + toolName := chunk.Start.ToolUse.Name + state.ItemIDs[outputIndex] = toolUseID + state.ToolCallIDs[outputIndex] = toolUseID + state.ToolCallNames[outputIndex] = toolName - return bifrostReq, nil -} + statusInProgress := "in_progress" + item := &schemas.ResponsesMessage{ + ID: &toolUseID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: &statusInProgress, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolUseID, + Name: &toolName, + Arguments: schemas.Ptr(""), // Arguments will be filled by deltas + }, + } -// ToBedrockResponsesRequest converts a BifrostRequest (Responses structure) back to BedrockConverseRequest -func ToBedrockResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*BedrockConverseRequest, error) { - if bifrostReq == nil { - return nil, fmt.Errorf("bifrost request is nil") - } + // Initialize argument buffer for this tool call + state.ToolArgumentBuffers[outputIndex] = "" - bedrockReq := &BedrockConverseRequest{ - ModelID: bifrostReq.Model, - } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(contentBlockIndex), + Item: item, + }) - // map bifrost messages to bedrock messages - if bifrostReq.Input != nil { - messages, systemMessages, err := convertResponsesItemsToBedrockMessages(bifrostReq.Input) - if err != nil { - return nil, fmt.Errorf("failed to convert Responses messages: %w", err) - } - bedrockReq.Messages = messages - if len(systemMessages) > 0 { - bedrockReq.System = systemMessages + return responses, nil, false } - } + // Text content start is handled by Role event, so we can ignore Start for text - // Map basic parameters to inference config - if bifrostReq.Params != nil { - inferenceConfig := &BedrockInferenceConfig{} + case chunk.ContentBlockIndex != nil && chunk.Delta != nil: + // Handle contentBlockDelta event + contentBlockIndex := *chunk.ContentBlockIndex + outputIndex, exists := state.ContentIndexToOutputIndex[contentBlockIndex] + if !exists { + // Check if this is a new content block that should close previous reasoning blocks + var responses []*schemas.BifrostResponsesStreamResponse - if bifrostReq.Params.MaxOutputTokens != nil { - inferenceConfig.MaxTokens = bifrostReq.Params.MaxOutputTokens - } - if bifrostReq.Params.Temperature != nil { - inferenceConfig.Temperature = bifrostReq.Params.Temperature - } - if bifrostReq.Params.TopP != nil { - inferenceConfig.TopP = bifrostReq.Params.TopP - } - if bifrostReq.Params.Reasoning != nil { - if bedrockReq.AdditionalModelRequestFields == nil { - bedrockReq.AdditionalModelRequestFields = make(schemas.OrderedMap) - } - if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort == "none" { - bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]string{ - "type": "disabled", - } - } else { - if bifrostReq.Params.Reasoning.MaxTokens == nil { - return nil, fmt.Errorf("reasoning.max_tokens is required for reasoning") - } else if schemas.IsAnthropicModel(bedrockReq.ModelID) && *bifrostReq.Params.Reasoning.MaxTokens < anthropic.MinimumReasoningMaxTokens { - return nil, fmt.Errorf("reasoning.max_tokens must be greater than or equal to %d", anthropic.MinimumReasoningMaxTokens) - } else { - bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]any{ - "type": "enabled", - "budget_tokens": *bifrostReq.Params.Reasoning.MaxTokens, - } - } - } - } - if bifrostReq.Params.ExtraParams != nil { - if stop, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["stop"]); ok { - inferenceConfig.StopSequences = stop - } + // If this is a text delta with a new content block index, close any open reasoning blocks + if chunk.Delta.Text != nil && contentBlockIndex > 0 { + for prevContentIndex := range state.ReasoningContentIndices { + if prevContentIndex < contentBlockIndex { + prevOutputIndex, prevExists := state.ContentIndexToOutputIndex[prevContentIndex] + if !prevExists { + continue + } - if requestFields, exists := bifrostReq.Params.ExtraParams["additionalModelRequestFieldPaths"]; exists { - if orderedFields, ok := schemas.SafeExtractOrderedMap(requestFields); ok { - bedrockReq.AdditionalModelRequestFields = orderedFields - } - } + itemID := state.ItemIDs[prevOutputIndex] - if responseFields, exists := bifrostReq.Params.ExtraParams["additionalModelResponseFieldPaths"]; exists { - if fields, ok := responseFields.([]string); ok { - bedrockReq.AdditionalModelResponseFieldPaths = fields - } else if fieldsInterface, ok := responseFields.([]interface{}); ok { - stringFields := make([]string, 0, len(fieldsInterface)) - for _, field := range fieldsInterface { - if fieldStr, ok := field.(string); ok { - stringFields = append(stringFields, fieldStr) + // For reasoning items, content_index is always 0 + reasoningContentIndex := 0 + + // Emit reasoning_summary_text.done + emptyText := "" + reasoningDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: &reasoningContentIndex, + Text: &emptyText, } - } - if len(stringFields) > 0 { - bedrockReq.AdditionalModelResponseFieldPaths = stringFields + if itemID != "" { + reasoningDoneResponse.ItemID = &itemID + } + responses = append(responses, reasoningDoneResponse) + + // Emit content_part.done for reasoning + partDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: &reasoningContentIndex, + } + if itemID != "" { + partDoneResponse.ItemID = &itemID + } + responses = append(responses, partDoneResponse) + + // Emit output_item.done for reasoning + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, + } + if itemID != "" { + doneItem.ID = &itemID + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(prevOutputIndex), + ContentIndex: &reasoningContentIndex, + Item: doneItem, + }) + + // Clear the reasoning content index tracking + delete(state.ReasoningContentIndices, prevContentIndex) + + // Mark this output index as completed + state.CompletedOutputIndices[prevOutputIndex] = true } } } - } - - bedrockReq.InferenceConfig = inferenceConfig - if bifrostReq.Params.ServiceTier != nil { - bedrockReq.ServiceTier = &BedrockServiceTier{ - Type: *bifrostReq.Params.ServiceTier, - } - } - } + // Create new output index for this content block + outputIndex = state.CurrentOutputIndex + state.CurrentOutputIndex++ + state.ContentIndexToOutputIndex[contentBlockIndex] = outputIndex - // Convert tools - if bifrostReq.Params != nil && bifrostReq.Params.Tools != nil { - var bedrockTools []BedrockTool - for _, tool := range bifrostReq.Params.Tools { - if tool.ResponsesToolFunction != nil { - // Create the complete schema object that Bedrock expects - var schemaObject interface{} - if tool.ResponsesToolFunction.Parameters != nil { - schemaObject = tool.ResponsesToolFunction.Parameters + // If this is a text delta for a new content block, create the text item + if chunk.Delta.Text != nil { + // Generate stable ID for text item + var itemID string + if state.MessageID == nil { + itemID = fmt.Sprintf("item_%d", outputIndex) } else { - // Fallback to empty object schema if no parameters - schemaObject = map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{}, - } + itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) } - - if tool.Name == nil || *tool.Name == "" { - return nil, fmt.Errorf("responses tool is missing required name for Bedrock function conversion") + state.ItemIDs[outputIndex] = itemID + + // Create text item + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // Empty blocks slice for mutation support + }, } - name := *tool.Name - // Use the tool description if available, otherwise use a generic description - description := "Function tool" - if tool.Description != nil { - description = *tool.Description + // Emit output_item.added for text + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Item: item, + }) + + // Emit content_part.added with empty output_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &emptyText, } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + ItemID: &itemID, + Part: part, + }) + } - bedrockTool := BedrockTool{ - ToolSpec: &BedrockToolSpec{ - Name: name, - Description: &description, - InputSchema: BedrockToolInputSchema{ - JSON: schemaObject, - }, - }, + // If this is a text delta for a new content block, also emit the text delta in the same batch + if chunk.Delta.Text != nil && *chunk.Delta.Text != "" { + text := *chunk.Delta.Text + itemID := state.ItemIDs[outputIndex] + textDeltaResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Delta: &text, } - bedrockTools = append(bedrockTools, bedrockTool) + if itemID != "" { + textDeltaResponse.ItemID = &itemID + } + responses = append(responses, textDeltaResponse) } - } - if len(bedrockTools) > 0 { - bedrockReq.ToolConfig = &BedrockToolConfig{ - Tools: bedrockTools, + // If we have responses to return (either from closing reasoning or creating text item), return them first + if len(responses) > 0 { + return responses, nil, false } } - } - // Convert tool choice - if bifrostReq.Params != nil && bifrostReq.Params.ToolChoice != nil { - bedrockToolChoice := convertResponsesToolChoice(*bifrostReq.Params.ToolChoice) - if bedrockToolChoice != nil { - if bedrockReq.ToolConfig == nil { - bedrockReq.ToolConfig = &BedrockToolConfig{} + switch { + case chunk.Delta.Text != nil: + // Handle text delta + text := *chunk.Delta.Text + if text != "" { + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Delta: &text, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false } - bedrockReq.ToolConfig.ToolChoice = bedrockToolChoice - } - } - // Ensure tool config is present when tool content exists (similar to Chat Completions) - ensureResponsesToolConfigForConversation(bifrostReq, bedrockReq) + case chunk.Delta.ToolUse != nil: + // Handle tool use delta - function call arguments + toolUseDelta := chunk.Delta.ToolUse - return bedrockReq, nil -} + if toolUseDelta.Input != "" { + // Accumulate argument deltas + state.ToolArgumentBuffers[outputIndex] += toolUseDelta.Input -// ensureResponsesToolConfigForConversation ensures toolConfig is present when tool content exists -func ensureResponsesToolConfigForConversation(bifrostReq *schemas.BifrostResponsesRequest, bedrockReq *BedrockConverseRequest) { - if bedrockReq.ToolConfig != nil { - return // Already has tool config - } + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Delta: &toolUseDelta.Input, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } - hasToolContent, tools := extractToolsFromResponsesConversationHistory(bifrostReq.Input) - if hasToolContent && len(tools) > 0 { - bedrockReq.ToolConfig = &BedrockToolConfig{Tools: tools} - } -} + case chunk.Delta.ReasoningContent != nil: + // Handle reasoning content delta + reasoningDelta := chunk.Delta.ReasoningContent -// extractToolsFromResponsesConversationHistory extracts tools from Responses conversation history -func extractToolsFromResponsesConversationHistory(messages []schemas.ResponsesMessage) (bool, []BedrockTool) { - var hasToolContent bool - toolMap := make(map[string]*schemas.ResponsesTool) // Use map to deduplicate by name + // Check if this is the first reasoning delta for this content block + if !state.ReasoningContentIndices[contentBlockIndex] { + // First reasoning delta - emit output_item.added and content_part.added + var responses []*schemas.BifrostResponsesStreamResponse - for _, msg := range messages { - // Check if message contains tool use or tool result - if msg.Type != nil { - switch *msg.Type { - case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeFunctionCallOutput: - hasToolContent = true - // Try to infer tool definition from tool call/result - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { - toolName := *msg.ResponsesToolMessage.Name - if _, exists := toolMap[toolName]; !exists { - // Create a minimal tool definition - toolMap[toolName] = &schemas.ResponsesTool{ - Type: "function", - Name: &toolName, - ResponsesToolFunction: &schemas.ResponsesToolFunction{ - Parameters: &schemas.ToolFunctionParameters{ - Type: "object", - Properties: &schemas.OrderedMap{}, - }, - }, - } - } + // Generate stable ID for reasoning item + var itemID string + if state.MessageID == nil { + itemID = fmt.Sprintf("reasoning_%d", outputIndex) + } else { + itemID = fmt.Sprintf("msg_%s_reasoning_%d", *state.MessageID, outputIndex) } - } - } - } - - // Convert map to slice - var tools []BedrockTool - for _, tool := range toolMap { - if tool.Name != nil && tool.ResponsesToolFunction != nil { - schemaObject := tool.ResponsesToolFunction.Parameters - if schemaObject == nil { - schemaObject = &schemas.ToolFunctionParameters{ - Type: "object", - Properties: &schemas.OrderedMap{}, + state.ItemIDs[outputIndex] = itemID + + // Create reasoning item + messageType := schemas.ResponsesMessageTypeReasoning + role := schemas.ResponsesInputMessageRoleAssistant + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + }, } - } - description := "Function tool" - if tool.Description != nil { - description = *tool.Description - } + // Preserve signature if present + if reasoningDelta.Signature != nil { + item.ResponsesReasoning.EncryptedContent = reasoningDelta.Signature + } - bedrockTool := BedrockTool{ - ToolSpec: &BedrockToolSpec{ - Name: *tool.Name, - Description: &description, - InputSchema: BedrockToolInputSchema{ - JSON: schemaObject, - }, - }, - } - tools = append(tools, bedrockTool) - } - } + // Track that this content index is a reasoning block + state.ReasoningContentIndices[contentBlockIndex] = true - return hasToolContent, tools -} + // Emit output_item.added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Item: item, + }) -// ToBifrostResponsesResponse converts BedrockConverseResponse to BifrostResponsesResponse -func (response *BedrockConverseResponse) ToBifrostResponsesResponse() (*schemas.BifrostResponsesResponse, error) { - if response == nil { - return nil, fmt.Errorf("bedrock response is nil") - } + // Emit content_part.added with empty reasoning_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: &emptyText, + } + // Preserve signature in the content part if present + if reasoningDelta.Signature != nil { + part.Signature = reasoningDelta.Signature + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + ItemID: &itemID, + Part: part, + }) - bifrostResp := &schemas.BifrostResponsesResponse{ - CreatedAt: int(time.Now().Unix()), - } + // If there's text content, also emit the delta + if reasoningDelta.Text != "" { + deltaResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Delta: &reasoningDelta.Text, + ItemID: &itemID, + } + responses = append(responses, deltaResponse) + } - // Convert output message to Responses format - if response.Output != nil && response.Output.Message != nil { - outputMessages := convertBedrockMessageToResponsesMessages(*response.Output.Message) - bifrostResp.Output = outputMessages - } + return responses, nil, false + } else { + // Subsequent reasoning deltas - just emit the delta + if reasoningDelta.Text != "" { + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Delta: &reasoningDelta.Text, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } - if response.Usage != nil { - // Convert usage information - bifrostResp.Usage = &schemas.ResponsesResponseUsage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.TotalTokens, - } - // Handle cached tokens if present - if response.Usage.CacheReadInputTokens > 0 { - bifrostResp.Usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ - CachedTokens: response.Usage.CacheReadInputTokens, - } - } - if response.Usage.CacheWriteInputTokens > 0 { - bifrostResp.Usage.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ - CachedTokens: response.Usage.CacheWriteInputTokens, + // Handle signature deltas + if reasoningDelta.Signature != nil { + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentBlockIndex, + Signature: reasoningDelta.Signature, // Use signature field instead of delta + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } } } - } - if response.ServiceTier != nil && response.ServiceTier.Type != "" { - bifrostResp.ServiceTier = &response.ServiceTier.Type + case chunk.StopReason != nil: + // Stop reason - track it for the final response + var stopReason string + switch *chunk.StopReason { + case "tool_use": + stopReason = "tool_calls" + case "end_turn": + stopReason = "stop" + case "max_tokens": + stopReason = "length" + default: + stopReason = *chunk.StopReason + } + state.StopReason = &stopReason + // Items should be closed explicitly when content blocks end + return nil, nil, false } - return bifrostResp, nil + return nil, nil, false } -// Helper functions +// FinalizeBedrockStream finalizes the stream by closing any open items and emitting completed event +func FinalizeBedrockStream(state *BedrockResponsesStreamState, sequenceNumber int, usage *schemas.ResponsesResponseUsage) []*schemas.BifrostResponsesStreamResponse { + var responses []*schemas.BifrostResponsesStreamResponse -func convertResponsesToolChoice(toolChoice schemas.ResponsesToolChoice) *BedrockToolChoice { - // Check if it's a string choice - if toolChoice.ResponsesToolChoiceStr != nil { - switch schemas.ResponsesToolChoiceType(*toolChoice.ResponsesToolChoiceStr) { - case schemas.ResponsesToolChoiceTypeAny, schemas.ResponsesToolChoiceTypeRequired: - return &BedrockToolChoice{ - Any: &BedrockToolChoiceAny{}, - } - case schemas.ResponsesToolChoiceTypeNone: - // Bedrock doesn't have explicit "none" - just don't include tools - return nil + // Close any open items (text items and tool calls) + for contentIndex, outputIndex := range state.ContentIndexToOutputIndex { + // Skip reasoning blocks + if state.ReasoningContentIndices[contentIndex] { + continue } - } - // Check if it's a struct choice - if toolChoice.ResponsesToolChoiceStruct != nil { - switch toolChoice.ResponsesToolChoiceStruct.Type { - case schemas.ResponsesToolChoiceTypeFunction: - // Extract the actual function name from the struct - if toolChoice.ResponsesToolChoiceStruct.Name != nil && *toolChoice.ResponsesToolChoiceStruct.Name != "" { - return &BedrockToolChoice{ - Tool: &BedrockToolChoiceTool{ - Name: *toolChoice.ResponsesToolChoiceStruct.Name, - }, - } - } - // If Name is nil or empty, return nil as we can't construct a valid tool choice - return nil - case schemas.ResponsesToolChoiceTypeAuto, schemas.ResponsesToolChoiceTypeAny, schemas.ResponsesToolChoiceTypeRequired: - return &BedrockToolChoice{ - Any: &BedrockToolChoiceAny{}, - } - case schemas.ResponsesToolChoiceTypeNone: - return nil + // Skip already completed output indices + if state.CompletedOutputIndices[outputIndex] { + continue } - } - return nil -} + itemID := state.ItemIDs[outputIndex] + if itemID == "" { + continue + } -// convertResponsesItemsToBedrockMessages converts Responses items back to Bedrock messages -func convertResponsesItemsToBedrockMessages(messages []schemas.ResponsesMessage) ([]BedrockMessage, []BedrockSystemMessage, error) { - var bedrockMessages []BedrockMessage - var systemMessages []BedrockSystemMessage + // Check if this is a tool call by looking at the tool call IDs + toolCallID := state.ToolCallIDs[outputIndex] + isToolCall := toolCallID != "" - for _, msg := range messages { - // Handle Responses items - msgType := schemas.ResponsesMessageTypeMessage - if msg.Type != nil { - msgType = *msg.Type - } - switch msgType { - case schemas.ResponsesMessageTypeMessage: - // Check if Role is present, skip message if not - if msg.Role == nil { - continue - } + if isToolCall { + // This is a tool call that needs to be closed - // Extract role from the Responses message structure - role := *msg.Role + // Emit content_part.done for tool call + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentIndex, + ItemID: &itemID, + }) - if role == schemas.ResponsesInputMessageRoleSystem { - // Convert to system message - // Ensure Content and ContentStr are present - if msg.Content != nil { - if msg.Content.ContentStr != nil { - systemMessages = append(systemMessages, BedrockSystemMessage{ - Text: msg.Content.ContentStr, - }) - } else if msg.Content.ContentBlocks != nil { - for _, block := range msg.Content.ContentBlocks { - if block.Text != nil { - systemMessages = append(systemMessages, BedrockSystemMessage{ - Text: block.Text, - }) - } - } + // Emit function_call_arguments.done with full arguments + toolName := state.ToolCallNames[outputIndex] + accumulatedArgs := state.ToolArgumentBuffers[outputIndex] + if accumulatedArgs != "" { + var doneItem *schemas.ResponsesMessage + if toolCallID != "" || toolName != "" { + doneItem = &schemas.ResponsesMessage{ + ResponsesToolMessage: &schemas.ResponsesToolMessage{}, + } + if toolCallID != "" { + doneItem.ResponsesToolMessage.CallID = &toolCallID + } + if toolName != "" { + doneItem.ResponsesToolMessage.Name = &toolName } - } - // Skip system messages with no content - } else { - // Convert regular message - // Ensure Content is present - if msg.Content == nil { - // Skip messages without content or create with empty content - continue } - bedrockMsg := BedrockMessage{ - Role: BedrockMessageRole(role), + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + Arguments: &accumulatedArgs, } - - // Convert content - contentBlocks, err := convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(*msg.Content) - if err != nil { - return nil, nil, fmt.Errorf("failed to convert content blocks: %w", err) + if itemID != "" { + response.ItemID = &itemID } - bedrockMsg.Content = contentBlocks - - bedrockMessages = append(bedrockMessages, bedrockMsg) - } - - case schemas.ResponsesMessageTypeFunctionCall: - // Handle function calls from Responses - if msg.ResponsesToolMessage != nil { - // Create tool use content block - var toolUseID string - if msg.ResponsesToolMessage.CallID != nil { - toolUseID = *msg.ResponsesToolMessage.CallID + if doneItem != nil { + response.Item = doneItem } + responses = append(responses, response) + } - // Get function name from ToolMessage - var functionName string - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { - functionName = *msg.ResponsesToolMessage.Name - } + // Emit output_item.done for tool call + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + ID: &itemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: &statusCompleted, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolCallID, + Name: &toolName, + Arguments: &accumulatedArgs, + }, + } - // Parse JSON arguments into interface{} - var input interface{} = map[string]interface{}{} - if msg.ResponsesToolMessage.Arguments != nil { - var parsedInput interface{} - if err := json.Unmarshal([]byte(*msg.ResponsesToolMessage.Arguments), &parsedInput); err != nil { - return nil, nil, fmt.Errorf("failed to parse tool arguments JSON: %w", err) - } - input = parsedInput - } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentIndex, + ItemID: &itemID, + Item: doneItem, + }) + } else { + // This is likely a text item that needs to be closed - toolUseBlock := BedrockContentBlock{ - ToolUse: &BedrockToolUse{ - ToolUseID: toolUseID, - Name: functionName, - Input: input, - }, - } + // Emit output_text.done (without accumulated text, just the event) + emptyText := "" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentIndex, + ItemID: &itemID, + Text: &emptyText, + }) - // Create assistant message with tool use - assistantMsg := BedrockMessage{ - Role: BedrockMessageRoleAssistant, - Content: []BedrockContentBlock{toolUseBlock}, - } - bedrockMessages = append(bedrockMessages, assistantMsg) + // Emit content_part.done for text + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentIndex, + ItemID: &itemID, + }) + // Emit output_item.done for text + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, } + if itemID != "" { + doneItem.ID = &itemID + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &contentIndex, + Item: doneItem, + }) + } - case schemas.ResponsesMessageTypeFunctionCallOutput: - // Handle function call outputs from Responses - if msg.ResponsesToolMessage != nil { - // Check if we have output or error - hasOutput := msg.ResponsesToolMessage.Output != nil && - (msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil || - msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil) - hasError := msg.ResponsesToolMessage.Error != nil - - if hasOutput || hasError { - var toolUseID string - if msg.ResponsesToolMessage.CallID != nil { - toolUseID = *msg.ResponsesToolMessage.CallID - } - - status := "success" - if hasError { - status = "error" - } + // Mark this output index as completed + state.CompletedOutputIndices[outputIndex] = true + } - toolResultBlock := BedrockContentBlock{ - ToolResult: &BedrockToolResult{ - ToolUseID: toolUseID, - Status: schemas.Ptr(status), - }, + // Close any open reasoning items + for contentIndex := range state.ReasoningContentIndices { + outputIndex, exists := state.ContentIndexToOutputIndex[contentIndex] + if !exists { + continue + } + + // Skip already completed output indices + if state.CompletedOutputIndices[outputIndex] { + continue + } + + itemID := state.ItemIDs[outputIndex] + + // For reasoning items, content_index is always 0 (reasoning content is the first and only content part) + reasoningContentIndex := 0 + + // Emit reasoning_summary_text.done + emptyText := "" + reasoningDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &reasoningContentIndex, + Text: &emptyText, + } + if itemID != "" { + reasoningDoneResponse.ItemID = &itemID + } + responses = append(responses, reasoningDoneResponse) + + // Emit content_part.done for reasoning + partDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &reasoningContentIndex, + } + if itemID != "" { + partDoneResponse.ItemID = &itemID + } + responses = append(responses, partDoneResponse) + + // Emit output_item.done for reasoning + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, + } + if itemID != "" { + doneItem.ID = &itemID + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: &reasoningContentIndex, + Item: doneItem, + }) + + // Mark this output index as completed + state.CompletedOutputIndices[outputIndex] = true + } + + // Note: Tool calls are already closed in the first loop above. + // This section is intentionally left empty to avoid duplicate events. + + // Emit response.completed + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, + Usage: usage, + } + + if state.Model != nil { + response.Model = *state.Model + } + if state.StopReason != nil { + response.StopReason = state.StopReason + } else { + // Infer stop reason based on whether tool calls are present + hasToolCalls := false + for _, toolCallID := range state.ToolCallIDs { + if toolCallID != "" { + hasToolCalls = true + break + } + } + if hasToolCalls { + response.StopReason = schemas.Ptr("tool_calls") + } else { + response.StopReason = schemas.Ptr("stop") + } + } + + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + SequenceNumber: sequenceNumber + len(responses), + Response: response, + }) + + return responses +} + +// ToBedrockConverseStreamResponse converts a Bifrost Responses stream response to Bedrock streaming format +// Returns a BedrockStreamEvent that represents the streaming event in Bedrock's format +func ToBedrockConverseStreamResponse(bifrostResp *schemas.BifrostResponsesStreamResponse) (*BedrockStreamEvent, error) { + if bifrostResp == nil { + return nil, fmt.Errorf("bifrost stream response is nil") + } + + event := &BedrockStreamEvent{} + + switch bifrostResp.Type { + case schemas.ResponsesStreamResponseTypeCreated: + // Message start - emit role event + // Always set role for message start event + role := "assistant" + event.Role = &role + + case schemas.ResponsesStreamResponseTypeInProgress: + // In progress - no-op for Bedrock (it doesn't have an explicit in_progress event) + // Return nil to skip this event + return nil, nil + + case schemas.ResponsesStreamResponseTypeOutputItemAdded: + // Content block start + if bifrostResp.Item != nil && bifrostResp.Item.ResponsesToolMessage != nil { + // Tool use start + if bifrostResp.Item.ResponsesToolMessage.Name != nil && bifrostResp.Item.ResponsesToolMessage.CallID != nil { + contentBlockIndex := 0 + if bifrostResp.ContentIndex != nil { + contentBlockIndex = *bifrostResp.ContentIndex + } + event.ContentBlockIndex = &contentBlockIndex + event.Start = &BedrockContentBlockStart{ + ToolUse: &BedrockToolUseStart{ + ToolUseID: *bifrostResp.Item.ResponsesToolMessage.CallID, + Name: *bifrostResp.Item.ResponsesToolMessage.Name, + }, + } + } + } else if bifrostResp.Item != nil { + // Text item added - Bedrock doesn't have an explicit text start event, so we skip it + // Check if it's a text message (has content blocks or is a message type) + if bifrostResp.Item.Content != nil || (bifrostResp.Item.Type != nil && *bifrostResp.Item.Type == schemas.ResponsesMessageTypeMessage) { + return nil, nil + } + } + + case schemas.ResponsesStreamResponseTypeOutputTextDelta: + // Text delta + if bifrostResp.Delta != nil && *bifrostResp.Delta != "" { + contentBlockIndex := 0 + if bifrostResp.ContentIndex != nil { + contentBlockIndex = *bifrostResp.ContentIndex + } + event.ContentBlockIndex = &contentBlockIndex + event.Delta = &BedrockContentBlockDelta{ + Text: bifrostResp.Delta, + } + } else { + // Skip empty deltas + return nil, nil + } + + case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + // Tool use delta (function call arguments) + if bifrostResp.Delta != nil { + contentBlockIndex := 0 + if bifrostResp.ContentIndex != nil { + contentBlockIndex = *bifrostResp.ContentIndex + } + event.ContentBlockIndex = &contentBlockIndex + event.Delta = &BedrockContentBlockDelta{ + ToolUse: &BedrockToolUseDelta{ + Input: *bifrostResp.Delta, + }, + } + } + + case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta: + // Reasoning content delta + contentBlockIndex := 0 + if bifrostResp.ContentIndex != nil { + contentBlockIndex = *bifrostResp.ContentIndex + } + event.ContentBlockIndex = &contentBlockIndex + + // Check if this is a signature delta or text delta + if bifrostResp.Signature != nil { + // This is a signature delta + event.Delta = &BedrockContentBlockDelta{ + ReasoningContent: &BedrockReasoningContentText{ + Signature: bifrostResp.Signature, + }, + } + } else if bifrostResp.Delta != nil && *bifrostResp.Delta != "" { + // This is reasoning text delta + event.Delta = &BedrockContentBlockDelta{ + ReasoningContent: &BedrockReasoningContentText{ + Text: *bifrostResp.Delta, + }, + } + } else { + // Skip empty deltas + return nil, nil + } + + case schemas.ResponsesStreamResponseTypeOutputTextDone, + schemas.ResponsesStreamResponseTypeContentPartDone, + schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone: + // Content block done - Bedrock doesn't have explicit done events, so we skip them + return nil, nil + + case schemas.ResponsesStreamResponseTypeOutputItemDone: + // Item done - Bedrock doesn't have explicit done events, so we skip them + return nil, nil + + case schemas.ResponsesStreamResponseTypeCompleted: + // Message stop - always set stopReason + stopReason := "end_turn" + if bifrostResp.Response != nil && bifrostResp.Response.IncompleteDetails != nil { + stopReason = bifrostResp.Response.IncompleteDetails.Reason + } + event.StopReason = &stopReason + + // Add usage if available + if bifrostResp.Response != nil && bifrostResp.Response.Usage != nil { + event.Usage = &BedrockTokenUsage{ + InputTokens: bifrostResp.Response.Usage.InputTokens, + OutputTokens: bifrostResp.Response.Usage.OutputTokens, + TotalTokens: bifrostResp.Response.Usage.TotalTokens, + } + } + + case schemas.ResponsesStreamResponseTypeError: + // Error - errors are handled separately by the router via BifrostError in the stream chunk + // Return nil to skip this chunk + return nil, nil + + default: + // Unknown type - skip + return nil, nil + } + + return event, nil +} + +// BedrockEncodedEvent represents a single event ready for encoding to AWS Event Stream +type BedrockEncodedEvent struct { + EventType string + Payload interface{} +} + +// BedrockInvokeStreamChunkEvent represents the chunk event for invoke-with-response-stream +type BedrockInvokeStreamChunkEvent struct { + Bytes []byte `json:"bytes"` +} + +// ToEncodedEvents converts the flat BedrockStreamEvent into a sequence of specific events +func (event *BedrockStreamEvent) ToEncodedEvents() []BedrockEncodedEvent { + var events []BedrockEncodedEvent + + if event.InvokeModelRawChunk != nil { + events = append(events, BedrockEncodedEvent{ + EventType: "chunk", + Payload: BedrockInvokeStreamChunkEvent{ + Bytes: event.InvokeModelRawChunk, + }, + }) + } + + if event.Role != nil { + events = append(events, BedrockEncodedEvent{ + EventType: "messageStart", + Payload: BedrockMessageStartEvent{ + Role: *event.Role, + }, + }) + } + + if event.Start != nil { + events = append(events, BedrockEncodedEvent{ + EventType: "contentBlockStart", + Payload: struct { + Start *BedrockContentBlockStart `json:"start"` + ContentBlockIndex *int `json:"contentBlockIndex"` + }{ + Start: event.Start, + ContentBlockIndex: event.ContentBlockIndex, + }, + }) + } + + if event.Delta != nil { + events = append(events, BedrockEncodedEvent{ + EventType: "contentBlockDelta", + Payload: struct { + Delta *BedrockContentBlockDelta `json:"delta"` + ContentBlockIndex *int `json:"contentBlockIndex"` + }{ + Delta: event.Delta, + ContentBlockIndex: event.ContentBlockIndex, + }, + }) + } + + if event.StopReason != nil { + events = append(events, BedrockEncodedEvent{ + EventType: "messageStop", + Payload: BedrockMessageStopEvent{ + StopReason: *event.StopReason, + }, + }) + } + + if event.Usage != nil || event.Metrics != nil { + events = append(events, BedrockEncodedEvent{ + EventType: "metadata", + Payload: BedrockMetadataEvent{ + Usage: event.Usage, + Metrics: event.Metrics, + Trace: event.Trace, + }, + }) + } + + return events +} + +// ToBifrostResponsesRequest converts a BedrockConverseRequest to Bifrost Responses Request format +func (request *BedrockConverseRequest) ToBifrostResponsesRequest() (*schemas.BifrostResponsesRequest, error) { + if request == nil { + return nil, fmt.Errorf("bedrock request is nil") + } + + // Extract provider from model ID (format: "bedrock/model-name") + provider, model := schemas.ParseModelString(request.ModelID, schemas.Bedrock) + + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: provider, + Model: model, + Params: &schemas.ResponsesParameters{}, + Fallbacks: schemas.ParseFallbacks(request.Fallbacks), + } + + // Convert messages using the new conversion method + convertedMessages := ConvertBedrockMessagesToBifrostMessages(request.Messages, request.System, false) + bifrostReq.Input = convertedMessages + + // Convert inference config to parameters + if request.InferenceConfig != nil { + if request.InferenceConfig.MaxTokens != nil { + bifrostReq.Params.MaxOutputTokens = request.InferenceConfig.MaxTokens + } + if request.InferenceConfig.Temperature != nil { + bifrostReq.Params.Temperature = request.InferenceConfig.Temperature + } + if request.InferenceConfig.TopP != nil { + bifrostReq.Params.TopP = request.InferenceConfig.TopP + } + if len(request.InferenceConfig.StopSequences) > 0 { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } + bifrostReq.Params.ExtraParams["stop"] = request.InferenceConfig.StopSequences + } + } + + // Convert tool config + if request.ToolConfig != nil && len(request.ToolConfig.Tools) > 0 { + for _, tool := range request.ToolConfig.Tools { + if tool.ToolSpec != nil { + bifrostTool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFunction, + Name: &tool.ToolSpec.Name, + Description: tool.ToolSpec.Description, + ResponsesToolFunction: &schemas.ResponsesToolFunction{}, + } + + // Handle different types for InputSchema.JSON + if params, ok := tool.ToolSpec.InputSchema.JSON.(*schemas.ToolFunctionParameters); ok { + bifrostTool.ResponsesToolFunction.Parameters = params + } else if paramsMap, ok := tool.ToolSpec.InputSchema.JSON.(map[string]interface{}); ok { + // Convert map to ToolFunctionParameters + params := &schemas.ToolFunctionParameters{} + if typeVal, ok := paramsMap["type"].(string); ok { + params.Type = typeVal + } + // Handle both pointer and non-pointer properties + if props, ok := schemas.SafeExtractOrderedMap(paramsMap["properties"]); ok { + params.Properties = &props + } + if required, ok := paramsMap["required"].([]interface{}); ok { + reqStrings := make([]string, 0, len(required)) + for _, r := range required { + if rStr, ok := r.(string); ok { + reqStrings = append(reqStrings, rStr) + } + } + params.Required = reqStrings + } else if required, ok := paramsMap["required"].([]string); ok { + params.Required = required } + bifrostTool.ResponsesToolFunction.Parameters = params + } + + bifrostReq.Params.Tools = append(bifrostReq.Params.Tools, bifrostTool) + } + } + } + + // Convert guardrail config to extra params + if request.GuardrailConfig != nil { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } - // Set content based on available data - if hasError { - toolResultBlock.ToolResult.Content = []BedrockContentBlock{ - {Text: msg.ResponsesToolMessage.Error}, + guardrailMap := map[string]interface{}{ + "guardrailIdentifier": request.GuardrailConfig.GuardrailIdentifier, + "guardrailVersion": request.GuardrailConfig.GuardrailVersion, + } + if request.GuardrailConfig.Trace != nil { + guardrailMap["trace"] = *request.GuardrailConfig.Trace + } + bifrostReq.Params.ExtraParams["guardrailConfig"] = guardrailMap + } + + // Convert additional model request fields to extra params + if len(request.AdditionalModelRequestFields) > 0 { + reasoningConfig, ok := schemas.SafeExtractFromMap(request.AdditionalModelRequestFields, "reasoning_config") + if ok { + if reasoningConfigMap, ok := reasoningConfig.(map[string]interface{}); ok { + if typeStr, ok := schemas.SafeExtractString(reasoningConfigMap["type"]); ok { + if typeStr == "enabled" { + var summary *string + if summaryValue, ok := schemas.SafeExtractStringPointer(request.ExtraParams["reasoning_summary"]); ok { + summary = summaryValue } - } else if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - raw := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr - var parsed interface{} - if err := json.Unmarshal([]byte(raw), &parsed); err == nil { - toolResultBlock.ToolResult.Content = []BedrockContentBlock{ - {JSON: parsed}, + if maxTokens, ok := schemas.SafeExtractInt(reasoningConfigMap["budget_tokens"]); ok { + minBudgetTokens := 0 + if schemas.IsAnthropicModel(bifrostReq.Model) { + minBudgetTokens = anthropic.MinimumReasoningMaxTokens } - } else { - toolResultBlock.ToolResult.Content = []BedrockContentBlock{ - {Text: &raw}, + effort := providerUtils.GetReasoningEffortFromBudgetTokens(maxTokens, minBudgetTokens, *request.InferenceConfig.MaxTokens) + bifrostReq.Params.Reasoning = &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr(effort), + MaxTokens: schemas.Ptr(maxTokens), + Summary: summary, } } - } else if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { - toolResultContent, err := convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(schemas.ResponsesMessageContent{ - ContentBlocks: msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks, - }) - if err != nil { - return nil, nil, fmt.Errorf("failed to convert tool result content blocks: %w", err) + } else { + bifrostReq.Params.Reasoning = &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("none"), } - toolResultBlock.ToolResult.Content = toolResultContent - } - - // Create user message with tool result - userMsg := BedrockMessage{ - Role: BedrockMessageRoleUser, - Content: []BedrockContentBlock{toolResultBlock}, } - bedrockMessages = append(bedrockMessages, userMsg) } } } } - return bedrockMessages, systemMessages, nil + if include, ok := schemas.SafeExtractStringSlice(request.ExtraParams["include"]); ok { + bifrostReq.Params.Include = include + } + + // Convert performance config to extra params + if request.PerformanceConfig != nil { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } + + perfConfigMap := map[string]interface{}{} + if request.PerformanceConfig.Latency != nil { + perfConfigMap["latency"] = *request.PerformanceConfig.Latency + } + if len(perfConfigMap) > 0 { + bifrostReq.Params.ExtraParams["performanceConfig"] = perfConfigMap + } + } + + // Convert prompt variables to extra params + if len(request.PromptVariables) > 0 { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } + + promptVarsMap := make(map[string]interface{}) + for key, value := range request.PromptVariables { + varMap := map[string]interface{}{} + if value.Text != nil { + varMap["text"] = *value.Text + } + if len(varMap) > 0 { + promptVarsMap[key] = varMap + } + } + if len(promptVarsMap) > 0 { + bifrostReq.Params.ExtraParams["promptVariables"] = promptVarsMap + } + } + + // Convert request metadata to extra params + if len(request.RequestMetadata) > 0 { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } + bifrostReq.Params.ExtraParams["requestMetadata"] = request.RequestMetadata + } + + // Convert additional model request fields to extra params + if len(request.AdditionalModelRequestFields) > 0 { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } + bifrostReq.Params.ExtraParams["additionalModelRequestFieldPaths"] = request.AdditionalModelRequestFields + } + + // Convert additional model response field paths to extra params + if len(request.AdditionalModelResponseFieldPaths) > 0 { + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } + bifrostReq.Params.ExtraParams["additionalModelResponseFieldPaths"] = request.AdditionalModelResponseFieldPaths + } + + return bifrostReq, nil } -// convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks converts Bifrost content to Bedrock content blocks -func convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(content schemas.ResponsesMessageContent) ([]BedrockContentBlock, error) { - var blocks []BedrockContentBlock +// ToBedrockResponsesRequest converts a BifrostRequest (Responses structure) back to BedrockConverseRequest +func ToBedrockResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*BedrockConverseRequest, error) { + if bifrostReq == nil { + return nil, fmt.Errorf("bifrost request is nil") + } - if content.ContentStr != nil { - blocks = append(blocks, BedrockContentBlock{ - Text: content.ContentStr, - }) - } else if content.ContentBlocks != nil { - for _, block := range content.ContentBlocks { + bedrockReq := &BedrockConverseRequest{ + ModelID: bifrostReq.Model, + } - bedrockBlock := BedrockContentBlock{} + // map bifrost messages to bedrock messages using the new conversion method + if bifrostReq.Input != nil { + messages, systemMessages, err := ConvertBifrostMessagesToBedrockMessages(bifrostReq.Input) + if err != nil { + return nil, fmt.Errorf("failed to convert Responses messages: %w", err) + } + bedrockReq.Messages = messages + if len(systemMessages) > 0 { + bedrockReq.System = systemMessages + } + } - switch block.Type { - case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: - bedrockBlock.Text = block.Text - case schemas.ResponsesInputMessageContentBlockTypeImage: - if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { - imageSource, err := convertImageToBedrockSource(*block.ResponsesInputMessageContentBlockImage.ImageURL) + // Map basic parameters to inference config + if bifrostReq.Params != nil { + inferenceConfig := &BedrockInferenceConfig{} + + if bifrostReq.Params.MaxOutputTokens != nil { + inferenceConfig.MaxTokens = bifrostReq.Params.MaxOutputTokens + } + if bifrostReq.Params.Temperature != nil { + inferenceConfig.Temperature = bifrostReq.Params.Temperature + } + if bifrostReq.Params.TopP != nil { + inferenceConfig.TopP = bifrostReq.Params.TopP + } + if bifrostReq.Params.Reasoning != nil { + if bedrockReq.AdditionalModelRequestFields == nil { + bedrockReq.AdditionalModelRequestFields = make(schemas.OrderedMap) + } + if bifrostReq.Params.Reasoning.MaxTokens != nil { + if schemas.IsAnthropicModel(bifrostReq.Model) && *bifrostReq.Params.Reasoning.MaxTokens < anthropic.MinimumReasoningMaxTokens { + return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", anthropic.MinimumReasoningMaxTokens) + } + bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]any{ + "type": "enabled", + "budget_tokens": *bifrostReq.Params.Reasoning.MaxTokens, + } + } else { + if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort != "none" { + minBudgetTokens := MinimumReasoningMaxTokens + if schemas.IsAnthropicModel(bifrostReq.Model) { + minBudgetTokens = anthropic.MinimumReasoningMaxTokens + } + defaultMaxTokens := DefaultCompletionMaxTokens + if inferenceConfig.MaxTokens != nil { + defaultMaxTokens = *inferenceConfig.MaxTokens + } else { + inferenceConfig.MaxTokens = schemas.Ptr(DefaultCompletionMaxTokens) + } + budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, minBudgetTokens, defaultMaxTokens) if err != nil { - return nil, fmt.Errorf("failed to convert image in responses content block: %w", err) + return nil, err + } + bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]any{ + "type": "enabled", + "budget_tokens": budgetTokens, + } + } else { + bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]string{ + "type": "disabled", + } + } + } + } + if bifrostReq.Params.ExtraParams != nil { + if stop, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["stop"]); ok { + inferenceConfig.StopSequences = stop + } + + if requestFields, exists := bifrostReq.Params.ExtraParams["additionalModelRequestFieldPaths"]; exists { + if orderedFields, ok := schemas.SafeExtractOrderedMap(requestFields); ok { + bedrockReq.AdditionalModelRequestFields = orderedFields + } + } + + if responseFields, exists := bifrostReq.Params.ExtraParams["additionalModelResponseFieldPaths"]; exists { + if fields, ok := responseFields.([]string); ok { + bedrockReq.AdditionalModelResponseFieldPaths = fields + } else if fieldsInterface, ok := responseFields.([]interface{}); ok { + stringFields := make([]string, 0, len(fieldsInterface)) + for _, field := range fieldsInterface { + if fieldStr, ok := field.(string); ok { + stringFields = append(stringFields, fieldStr) + } + } + if len(stringFields) > 0 { + bedrockReq.AdditionalModelResponseFieldPaths = stringFields } - bedrockBlock.Image = imageSource } - default: - // Don't add anything } + } - blocks = append(blocks, bedrockBlock) + bedrockReq.InferenceConfig = inferenceConfig + + if bifrostReq.Params.ServiceTier != nil { + bedrockReq.ServiceTier = &BedrockServiceTier{ + Type: *bifrostReq.Params.ServiceTier, + } } } - return blocks, nil -} + // Convert tools + if bifrostReq.Params != nil && bifrostReq.Params.Tools != nil { + var bedrockTools []BedrockTool + for _, tool := range bifrostReq.Params.Tools { + if tool.ResponsesToolFunction != nil { + // Create the complete schema object that Bedrock expects + var schemaObject interface{} + if tool.ResponsesToolFunction.Parameters != nil { + schemaObject = tool.ResponsesToolFunction.Parameters + } else { + // Fallback to empty object schema if no parameters + schemaObject = map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + } -// convertBedrockMessageToResponsesMessages converts Bedrock message to ChatMessage output format -func convertBedrockMessageToResponsesMessages(bedrockMsg BedrockMessage) []schemas.ResponsesMessage { - var outputMessages []schemas.ResponsesMessage + if tool.Name == nil || *tool.Name == "" { + return nil, fmt.Errorf("responses tool is missing required name for Bedrock function conversion") + } + name := *tool.Name - for _, block := range bedrockMsg.Content { - if block.Text != nil { - // Text content - outputMessages = append(outputMessages, schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: block.Text, + // Use the tool description if available, otherwise use a generic description + description := "Function tool" + if tool.Description != nil { + description = *tool.Description + } + + bedrockTool := BedrockTool{ + ToolSpec: &BedrockToolSpec{ + Name: name, + Description: &description, + InputSchema: BedrockToolInputSchema{ + JSON: schemaObject, }, }, - }, - }) - } else if block.ToolUse != nil { - // Tool use content - // Create copies of the values to avoid range loop variable capture - toolUseID := block.ToolUse.ToolUseID - toolUseName := block.ToolUse.Name + } + bedrockTools = append(bedrockTools, bedrockTool) + } + } - toolMsg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: &toolUseID, - Name: &toolUseName, - Arguments: schemas.Ptr(schemas.JsonifyInput(block.ToolUse.Input)), - }, + if len(bedrockTools) > 0 { + bedrockReq.ToolConfig = &BedrockToolConfig{ + Tools: bedrockTools, } - outputMessages = append(outputMessages, toolMsg) - } else if block.ToolResult != nil { - // Tool result content - typically not in assistant output but handled for completeness - // Prefer JSON payloads without unmarshalling; fallback to text - var resultContent string - if len(block.ToolResult.Content) > 0 { - // JSON first (no unmarshal; just one marshal to string when present) - for _, c := range block.ToolResult.Content { - if c.JSON != nil { - resultContent = schemas.JsonifyInput(c.JSON) - break - } - } - // Fallback to first available text block - if resultContent == "" { - for _, c := range block.ToolResult.Content { - if c.Text != nil { - resultContent = *c.Text - break - } - } - } + } + } + + // Convert tool choice + if bifrostReq.Params != nil && bifrostReq.Params.ToolChoice != nil { + bedrockToolChoice := convertResponsesToolChoice(*bifrostReq.Params.ToolChoice) + if bedrockToolChoice != nil { + if bedrockReq.ToolConfig == nil { + bedrockReq.ToolConfig = &BedrockToolConfig{} } + bedrockReq.ToolConfig.ToolChoice = bedrockToolChoice + } + } - // Create a copy of the value to avoid range loop variable capture - toolResultID := block.ToolResult.ToolUseID + // Ensure tool config is present when tool content exists (similar to Chat Completions) + ensureResponsesToolConfigForConversation(bifrostReq, bedrockReq) - resultMsg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &resultContent, - }, - }, - }, - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: &toolResultID, - Output: &schemas.ResponsesToolMessageOutputStruct{ - ResponsesToolCallOutputStr: &resultContent, - }, - }, + return bedrockReq, nil +} + +// ToBifrostResponsesResponse converts BedrockConverseResponse to BifrostResponsesResponse +func (response *BedrockConverseResponse) ToBifrostResponsesResponse() (*schemas.BifrostResponsesResponse, error) { + if response == nil { + return nil, fmt.Errorf("bedrock response is nil") + } + + bifrostResp := &schemas.BifrostResponsesResponse{ + ID: schemas.Ptr(uuid.New().String()), + CreatedAt: int(time.Now().Unix()), + } + + // Convert output message to Responses format using the new conversion method + if response.Output != nil && response.Output.Message != nil { + outputMessages := ConvertBedrockMessagesToBifrostMessages([]BedrockMessage{*response.Output.Message}, []BedrockSystemMessage{}, true) + bifrostResp.Output = outputMessages + } + + if response.Usage != nil { + // Convert usage information + bifrostResp.Usage = &schemas.ResponsesResponseUsage{ + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + } + // Handle cached tokens if present + if response.Usage.CacheReadInputTokens > 0 { + bifrostResp.Usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + CachedTokens: response.Usage.CacheReadInputTokens, + } + } + if response.Usage.CacheWriteInputTokens > 0 { + bifrostResp.Usage.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ + CachedTokens: response.Usage.CacheWriteInputTokens, } - outputMessages = append(outputMessages, resultMsg) } } - return outputMessages + if response.ServiceTier != nil && response.ServiceTier.Type != "" { + bifrostResp.ServiceTier = &response.ServiceTier.Type + } + + return bifrostResp, nil } // ToBedrockConverseResponse converts Bifrost Responses response to Bedrock Converse response @@ -1047,699 +1638,1027 @@ func ToBedrockConverseResponse(bifrostResp *schemas.BifrostResponsesResponse) (* } if len(bifrostResp.Output) > 0 { - for _, outputMsg := range bifrostResp.Output { - // Check if this output message contains a tool use - if outputMsg.Type != nil && *outputMsg.Type == schemas.ResponsesMessageTypeFunctionCall { + // Convert Bifrost messages back to Bedrock messages using the new conversion method + bedrockMessages, _, err := ConvertBifrostMessagesToBedrockMessages(bifrostResp.Output) + if err != nil { + return nil, fmt.Errorf("failed to convert bifrost output messages: %w", err) + } + + // Merge all content blocks from converted messages into a single message + for _, bedrockMsg := range bedrockMessages { + message.Content = append(message.Content, bedrockMsg.Content...) + } + + // Check for tool use in the content blocks + for _, block := range message.Content { + if block.ToolUse != nil { hasToolUse = true + break } - // Handle content blocks - if outputMsg.Content != nil && outputMsg.Content.ContentBlocks != nil { - for _, content := range outputMsg.Content.ContentBlocks { - switch content.Type { - case schemas.ResponsesOutputMessageContentTypeText: - if content.Text != nil { - message.Content = append(message.Content, BedrockContentBlock{ - Text: content.Text, - }) - } - } - } - } + } + } - // Handle tool calls and tool results - if outputMsg.ResponsesToolMessage != nil { - // Check if this is a tool use (function_call type) - if outputMsg.Type != nil && *outputMsg.Type == schemas.ResponsesMessageTypeFunctionCall { - // This is a tool use - ensure we have required fields - if outputMsg.ResponsesToolMessage.Name != nil && outputMsg.ResponsesToolMessage.CallID != nil { - var input interface{} = map[string]interface{}{} - if outputMsg.ResponsesToolMessage.Arguments != nil { - var parsed interface{} - if err := json.Unmarshal([]byte(*outputMsg.ResponsesToolMessage.Arguments), &parsed); err == nil { - input = parsed - } else { - // Fallback to raw string if it's not valid JSON - input = *outputMsg.ResponsesToolMessage.Arguments - } - } - message.Content = append(message.Content, BedrockContentBlock{ - ToolUse: &BedrockToolUse{ - ToolUseID: *outputMsg.ResponsesToolMessage.CallID, - Name: *outputMsg.ResponsesToolMessage.Name, - Input: input, - }, - }) - } - } else if outputMsg.Type != nil && *outputMsg.Type == schemas.ResponsesMessageTypeFunctionCallOutput { - // This is a tool result - ensure we have required fields - if outputMsg.ResponsesToolMessage.CallID != nil && outputMsg.ResponsesToolMessage.Output != nil { - resultBlock := BedrockContentBlock{ - ToolResult: &BedrockToolResult{ - ToolUseID: *outputMsg.ResponsesToolMessage.CallID, - Status: schemas.Ptr("success"), + bedrockResp.Output.Message = message + + // Find stop reason from incomplete details or derive from response + // Priority: IncompleteDetails > tool_use detection > end_turn + stopReason := "end_turn" + if bifrostResp.IncompleteDetails != nil { + stopReason = bifrostResp.IncompleteDetails.Reason + } else if hasToolUse { + stopReason = "tool_use" + } + bedrockResp.StopReason = stopReason + + // Convert usage stats + if bifrostResp.Usage != nil { + bedrockResp.Usage.InputTokens = bifrostResp.Usage.InputTokens + bedrockResp.Usage.OutputTokens = bifrostResp.Usage.OutputTokens + bedrockResp.Usage.TotalTokens = bifrostResp.Usage.TotalTokens + } + + // Set metrics + if bifrostResp.ExtraFields.Latency > 0 { + bedrockResp.Metrics.LatencyMs = bifrostResp.ExtraFields.Latency + } + + return bedrockResp, nil +} + +// Helper functions + +// ensureResponsesToolConfigForConversation ensures toolConfig is present when tool content exists +func ensureResponsesToolConfigForConversation(bifrostReq *schemas.BifrostResponsesRequest, bedrockReq *BedrockConverseRequest) { + if bedrockReq.ToolConfig != nil { + return // Already has tool config + } + + hasToolContent, tools := extractToolsFromResponsesConversationHistory(bifrostReq.Input) + if hasToolContent && len(tools) > 0 { + bedrockReq.ToolConfig = &BedrockToolConfig{Tools: tools} + } +} + +// extractToolsFromResponsesConversationHistory extracts tools from Responses conversation history +func extractToolsFromResponsesConversationHistory(messages []schemas.ResponsesMessage) (bool, []BedrockTool) { + var hasToolContent bool + toolMap := make(map[string]*schemas.ResponsesTool) // Use map to deduplicate by name + + for _, msg := range messages { + // Check if message contains tool use or tool result + if msg.Type != nil { + switch *msg.Type { + case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeFunctionCallOutput: + hasToolContent = true + // Try to infer tool definition from tool call/result + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { + toolName := *msg.ResponsesToolMessage.Name + if _, exists := toolMap[toolName]; !exists { + // Create a minimal tool definition + toolMap[toolName] = &schemas.ResponsesTool{ + Type: "function", + Name: &toolName, + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{}, + }, }, } - var resultContent []BedrockContentBlock - if outputMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - raw := *outputMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr - var parsed interface{} - if err := json.Unmarshal([]byte(raw), &parsed); err == nil { - resultContent = append(resultContent, BedrockContentBlock{ - JSON: parsed, - }) - } else { - // Fallback to raw string if it's not valid JSON - resultContent = append(resultContent, BedrockContentBlock{ - Text: &raw, - }) - } - } else if outputMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { - converted, err := convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(schemas.ResponsesMessageContent{ - ContentBlocks: outputMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks, - }) - if err == nil { - resultContent = append(resultContent, converted...) - } else { - // Fallback to JSON string if conversion fails - fallback := schemas.JsonifyInput(outputMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) - resultContent = append(resultContent, BedrockContentBlock{ - Text: &fallback, - }) - } - } - if len(resultContent) > 0 { - resultBlock.ToolResult.Content = resultContent - message.Content = append(message.Content, resultBlock) - } } } } } + } - // Also check the final message content for tool use blocks (more robust) - if !hasToolUse { - for _, block := range message.Content { - if block.ToolUse != nil { - hasToolUse = true - break + // Convert map to slice + var tools []BedrockTool + for _, tool := range toolMap { + if tool.Name != nil && tool.ResponsesToolFunction != nil { + schemaObject := tool.ResponsesToolFunction.Parameters + if schemaObject == nil { + schemaObject = &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{}, } } - } - } - - bedrockResp.Output.Message = message - - // Find stop reason from incomplete details or derive from response - // Priority: IncompleteDetails > tool_use detection > end_turn - stopReason := "end_turn" - if bifrostResp.IncompleteDetails != nil { - stopReason = bifrostResp.IncompleteDetails.Reason - } else if hasToolUse { - stopReason = "tool_use" - } - bedrockResp.StopReason = stopReason - // Convert usage stats - if bifrostResp.Usage != nil { - bedrockResp.Usage.InputTokens = bifrostResp.Usage.InputTokens - bedrockResp.Usage.OutputTokens = bifrostResp.Usage.OutputTokens - bedrockResp.Usage.TotalTokens = bifrostResp.Usage.TotalTokens - } + description := "Function tool" + if tool.Description != nil { + description = *tool.Description + } - // Set metrics - if bifrostResp.ExtraFields.Latency > 0 { - bedrockResp.Metrics.LatencyMs = bifrostResp.ExtraFields.Latency + bedrockTool := BedrockTool{ + ToolSpec: &BedrockToolSpec{ + Name: *tool.Name, + Description: &description, + InputSchema: BedrockToolInputSchema{ + JSON: schemaObject, + }, + }, + } + tools = append(tools, bedrockTool) + } } - return bedrockResp, nil + return hasToolContent, tools } -// ToBifrostResponsesStream converts a Bedrock stream event to a Bifrost Responses Stream response -// Returns a slice of responses to support cases where a single event produces multiple responses -func (chunk *BedrockStreamEvent) ToBifrostResponsesStream(sequenceNumber int, state *BedrockResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError, bool) { - switch { - case chunk.Role != nil: - // Message start - emit response.created and response.in_progress (OpenAI-style lifecycle) - var responses []*schemas.BifrostResponsesStreamResponse - - // Generate message ID if not already set - if state.MessageID == nil { - messageID := fmt.Sprintf("msg_%d", state.CreatedAt) - state.MessageID = &messageID - } - - // Emit response.created - if !state.HasEmittedCreated { - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, - } - if state.Model != nil { - response.Model = *state.Model +func convertResponsesToolChoice(toolChoice schemas.ResponsesToolChoice) *BedrockToolChoice { + // Check if it's a string choice + if toolChoice.ResponsesToolChoiceStr != nil { + switch schemas.ResponsesToolChoiceType(*toolChoice.ResponsesToolChoiceStr) { + case schemas.ResponsesToolChoiceTypeAny, schemas.ResponsesToolChoiceTypeRequired: + return &BedrockToolChoice{ + Any: &BedrockToolChoiceAny{}, } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeCreated, - SequenceNumber: sequenceNumber, - Response: response, - }) - state.HasEmittedCreated = true + case schemas.ResponsesToolChoiceTypeNone: + // Bedrock doesn't have explicit "none" - just don't include tools + return nil } + } - // Emit response.in_progress - if !state.HasEmittedInProgress { - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, // Use same timestamp + // Check if it's a struct choice + if toolChoice.ResponsesToolChoiceStruct != nil { + switch toolChoice.ResponsesToolChoiceStruct.Type { + case schemas.ResponsesToolChoiceTypeFunction: + // Extract the actual function name from the struct + if toolChoice.ResponsesToolChoiceStruct.Name != nil && *toolChoice.ResponsesToolChoiceStruct.Name != "" { + return &BedrockToolChoice{ + Tool: &BedrockToolChoiceTool{ + Name: *toolChoice.ResponsesToolChoiceStruct.Name, + }, + } } - if state.Model != nil { - response.Model = *state.Model + // If Name is nil or empty, return nil as we can't construct a valid tool choice + return nil + case schemas.ResponsesToolChoiceTypeAuto, schemas.ResponsesToolChoiceTypeAny, schemas.ResponsesToolChoiceTypeRequired: + return &BedrockToolChoice{ + Any: &BedrockToolChoiceAny{}, } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeInProgress, - SequenceNumber: sequenceNumber + len(responses), - Response: response, - }) - state.HasEmittedInProgress = true - } - - // Emit output_item.added for text message - outputIndex := 0 - state.ContentIndexToOutputIndex[0] = outputIndex // Text is at content index 0 - - // Generate stable ID for text item - var itemID string - if state.MessageID == nil { - itemID = fmt.Sprintf("item_%d", outputIndex) - } else { - itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) - } - state.ItemIDs[outputIndex] = itemID - - messageType := schemas.ResponsesMessageTypeMessage - role := schemas.ResponsesInputMessageRoleAssistant - - item := &schemas.ResponsesMessage{ - ID: &itemID, - Type: &messageType, - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // Empty blocks slice for mutation support - }, - } - - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - Item: item, - }) - - // Emit content_part.added with empty output_text part - emptyText := "" - part := &schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &emptyText, - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - Part: part, - }) - - if len(responses) > 0 { - return responses, nil, false - } - - case chunk.Start != nil: - // Handle content block start (text content or tool use) - contentBlockIndex := 0 - if chunk.ContentBlockIndex != nil { - contentBlockIndex = *chunk.ContentBlockIndex + case schemas.ResponsesToolChoiceTypeNone: + return nil } + } - // Check if this is a tool use start - if chunk.Start.ToolUse != nil { - // Close text item if it's still open - var responses []*schemas.BifrostResponsesStreamResponse - if !state.TextItemClosed { - outputIndex := 0 - itemID := state.ItemIDs[outputIndex] + return nil +} - // Emit output_text.done (without accumulated text, just the event) - emptyText := "" - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - Text: &emptyText, - }) +// ToolCallState represents the state of a single tool call in the conversion process +type ToolCallState string - // Emit content_part.done - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - }) +const ( + // Tool call states + ToolCallStateInitialized ToolCallState = "initialized" // Tool call message received + ToolCallStateQueued ToolCallState = "queued" // Tool call queued for emission + ToolCallStateEmitted ToolCallState = "emitted" // Tool call emitted in assistant message + ToolCallStateAwaitingResult ToolCallState = "awaiting_result" // Waiting for tool result + ToolCallStateCompleted ToolCallState = "completed" // Tool call + result complete +) - // Emit output_item.done - statusCompleted := "completed" - doneItem := &schemas.ResponsesMessage{ - Status: &statusCompleted, - } - if itemID != "" { - doneItem.ID = &itemID - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - Item: doneItem, - }) - state.TextItemClosed = true - } +// ToolCall represents a tool call with its full lifecycle state +type ToolCall struct { + CallID string + ToolName string + Arguments string + State ToolCallState + AssistantMsgIndex int // Index in final bedrockMessages where this call was emitted + Result *ToolResult +} - // This is a function call starting - use output_index 1 - outputIndex := 1 - state.ContentIndexToOutputIndex[contentBlockIndex] = outputIndex - state.CurrentOutputIndex = 2 // Next available index +// ToolResult represents the result of a tool call +type ToolResult struct { + CallID string + Content []BedrockContentBlock + Status string + Emitted bool +} - // Store tool use ID as item ID and call ID - toolUseID := chunk.Start.ToolUse.ToolUseID - toolName := chunk.Start.ToolUse.Name - state.ItemIDs[outputIndex] = toolUseID - state.ToolCallIDs[outputIndex] = toolUseID - state.ToolCallNames[outputIndex] = toolName +// ToolCallBatch tracks a group of tool calls that should be emitted together +type ToolCallBatch struct { + ID string // Unique batch identifier + ToolCalls map[string]*ToolCall // Maps CallID to ToolCall + State ToolCallState + AssistantMsgIndex int // Where this batch's assistant message is in bedrockMessages +} - statusInProgress := "in_progress" - item := &schemas.ResponsesMessage{ - ID: &toolUseID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: &statusInProgress, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: &toolUseID, - Name: &toolName, - Arguments: schemas.Ptr(""), // Arguments will be filled by deltas - }, - } +// ToolCallStateManager manages the lifecycle of tool calls through conversion +type ToolCallStateManager struct { + // All tool calls indexed by ID + toolCalls map[string]*ToolCall - // Initialize argument buffer for this tool call - state.ToolArgumentBuffers[outputIndex] = "" + // Current batch being accumulated + currentBatch *ToolCallBatch + batches []*ToolCallBatch - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(contentBlockIndex), - Item: item, - }) + // Pending operations + pendingToolCallIDs []string // Tool calls waiting to be emitted + pendingResults map[string]*ToolResult // Results waiting to be matched +} - return responses, nil, false - } - // Text content start is handled by Role event, so we can ignore Start for text +// NewToolCallStateManager creates a new state manager +func NewToolCallStateManager() *ToolCallStateManager { + return &ToolCallStateManager{ + toolCalls: make(map[string]*ToolCall), + pendingResults: make(map[string]*ToolResult), + } +} - case chunk.ContentBlockIndex != nil && chunk.Delta != nil: - // Handle contentBlockDelta event - contentBlockIndex := *chunk.ContentBlockIndex - outputIndex, exists := state.ContentIndexToOutputIndex[contentBlockIndex] - if !exists { - // Default to 0 for text if not mapped - outputIndex = 0 - state.ContentIndexToOutputIndex[contentBlockIndex] = outputIndex - } +// RegisterToolCall registers a new tool call in the system +func (m *ToolCallStateManager) RegisterToolCall(callID, toolName, arguments string) { + if m.toolCalls[callID] != nil { + // Tool call already registered, skip + return + } - switch { - case chunk.Delta.Text != nil: - // Handle text delta - text := *chunk.Delta.Text - if text != "" { - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: &contentBlockIndex, - Delta: &text, - } - if itemID != "" { - response.ItemID = &itemID - } - return []*schemas.BifrostResponsesStreamResponse{response}, nil, false - } + toolCall := &ToolCall{ + CallID: callID, + ToolName: toolName, + Arguments: arguments, + State: ToolCallStateInitialized, + AssistantMsgIndex: -1, + } - case chunk.Delta.ToolUse != nil: - // Handle tool use delta - function call arguments - toolUseDelta := chunk.Delta.ToolUse + m.toolCalls[callID] = toolCall + m.pendingToolCallIDs = append(m.pendingToolCallIDs, callID) +} - if toolUseDelta.Input != "" { - // Accumulate argument deltas - state.ToolArgumentBuffers[outputIndex] += toolUseDelta.Input +// RegisterToolResult registers a tool result +func (m *ToolCallStateManager) RegisterToolResult(callID string, content []BedrockContentBlock, status string) { + result := &ToolResult{ + CallID: callID, + Content: content, + Status: status, + Emitted: false, + } - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: &contentBlockIndex, - Delta: &toolUseDelta.Input, - } - if itemID != "" { - response.ItemID = &itemID - } - return []*schemas.BifrostResponsesStreamResponse{response}, nil, false - } + m.pendingResults[callID] = result + + // If we have the corresponding tool call, attach the result + if toolCall, exists := m.toolCalls[callID]; exists { + toolCall.Result = result + if toolCall.State == ToolCallStateEmitted { + toolCall.State = ToolCallStateCompleted + } else if toolCall.State == ToolCallStateAwaitingResult { + toolCall.State = ToolCallStateCompleted } + } +} - case chunk.StopReason != nil: - // Stop reason - don't use it to close items, just return nil - // Items should be closed explicitly when content blocks end - return nil, nil, false +// EmitPendingToolCalls prepares all pending tool calls for emission as an assistant message +func (m *ToolCallStateManager) EmitPendingToolCalls() []string { + if len(m.pendingToolCallIDs) == 0 { + return nil } - return nil, nil, false -} + // Create a new batch for these tool calls + batchID := fmt.Sprintf("batch_%d", len(m.batches)) + batch := &ToolCallBatch{ + ID: batchID, + ToolCalls: make(map[string]*ToolCall), + State: ToolCallStateQueued, + } -// FinalizeBedrockStream finalizes the stream by closing any open items and emitting completed event -func FinalizeBedrockStream(state *BedrockResponsesStreamState, sequenceNumber int, usage *schemas.ResponsesResponseUsage) []*schemas.BifrostResponsesStreamResponse { - var responses []*schemas.BifrostResponsesStreamResponse + // Mark all pending tool calls as queued + for _, callID := range m.pendingToolCallIDs { + if toolCall, exists := m.toolCalls[callID]; exists { + toolCall.State = ToolCallStateQueued + batch.ToolCalls[callID] = toolCall + } + } - // Close text item if still open - if !state.TextItemClosed { - outputIndex := 0 - itemID := state.ItemIDs[outputIndex] + m.batches = append(m.batches, batch) + m.currentBatch = batch - // Emit output_text.done (without accumulated text, just the event) - emptyText := "" - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - Text: &emptyText, - }) + // Return the IDs that should be emitted + emitIDs := make([]string, len(m.pendingToolCallIDs)) + copy(emitIDs, m.pendingToolCallIDs) + m.pendingToolCallIDs = nil - // Emit content_part.done - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - }) + return emitIDs +} - // Emit output_item.done - statusCompleted := "completed" - doneItem := &schemas.ResponsesMessage{ - Status: &statusCompleted, - } - if itemID != "" { - doneItem.ID = &itemID +// MarkToolCallsEmitted marks tool calls as having been emitted in an assistant message +func (m *ToolCallStateManager) MarkToolCallsEmitted(callIDs []string, assistantMsgIndex int) { + for _, callID := range callIDs { + if toolCall, exists := m.toolCalls[callID]; exists { + toolCall.State = ToolCallStateEmitted + toolCall.AssistantMsgIndex = assistantMsgIndex } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - Item: doneItem, - }) - state.TextItemClosed = true } - // Close any open tool call items and emit function_call_arguments.done - for outputIndex, args := range state.ToolArgumentBuffers { - if args != "" { - itemID := state.ItemIDs[outputIndex] - callID := state.ToolCallIDs[outputIndex] - toolName := state.ToolCallNames[outputIndex] + if m.currentBatch != nil { + m.currentBatch.State = ToolCallStateEmitted + m.currentBatch.AssistantMsgIndex = assistantMsgIndex + } +} - // Create item with tool message info for the done event - var doneItem *schemas.ResponsesMessage - if callID != "" || toolName != "" { - doneItem = &schemas.ResponsesMessage{ - ResponsesToolMessage: &schemas.ResponsesToolMessage{}, - } - if callID != "" { - doneItem.ResponsesToolMessage.CallID = &callID - } - if toolName != "" { - doneItem.ResponsesToolMessage.Name = &toolName - } +// GetPendingResults returns all pending results that are ready to be emitted +func (m *ToolCallStateManager) GetPendingResults() map[string]*ToolResult { + return m.pendingResults +} + +// MarkResultsEmitted marks results as having been emitted in a user message +func (m *ToolCallStateManager) MarkResultsEmitted(callIDs []string) { + for _, callID := range callIDs { + if result, exists := m.pendingResults[callID]; exists { + result.Emitted = true + delete(m.pendingResults, callID) + + // Update tool call state + if toolCall, exists := m.toolCalls[callID]; exists { + toolCall.State = ToolCallStateCompleted } + } + } +} - // Emit function_call_arguments.done with full arguments - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - Arguments: &args, +// HasPendingToolCalls checks if there are tool calls waiting to be emitted +func (m *ToolCallStateManager) HasPendingToolCalls() bool { + return len(m.pendingToolCallIDs) > 0 +} + +// HasPendingResults checks if there are results waiting to be emitted +func (m *ToolCallStateManager) HasPendingResults() bool { + return len(m.pendingResults) > 0 +} + +// ConvertBifrostMessagesToBedrockMessages converts an array of Bifrost ResponsesMessage to Bedrock message format +// This is the main conversion method from Bifrost to Bedrock - handles all message types and returns messages + system messages +// Uses a state machine to properly track and manage tool call lifecycles +func ConvertBifrostMessagesToBedrockMessages(bifrostMessages []schemas.ResponsesMessage) ([]BedrockMessage, []BedrockSystemMessage, error) { + var bedrockMessages []BedrockMessage + var systemMessages []BedrockSystemMessage + var pendingReasoningContentBlocks []BedrockContentBlock + + // Initialize the state manager for tracking tool calls and results + stateManager := NewToolCallStateManager() + + // Helper to flush pending tool result blocks into user messages using state manager + flushPendingToolResults := func() { + // Emit any pending results from the state manager + if stateManager.HasPendingResults() { + pendingResults := stateManager.GetPendingResults() + var resultBlocks []BedrockContentBlock + resultIDs := []string{} + for callID, result := range pendingResults { + resultBlocks = append(resultBlocks, BedrockContentBlock{ + ToolResult: &BedrockToolResult{ + ToolUseID: callID, + Content: result.Content, + Status: schemas.Ptr(result.Status), + }, + }) + resultIDs = append(resultIDs, callID) } - if itemID != "" { - response.ItemID = &itemID + + if len(resultBlocks) > 0 { + bedrockMessages = append(bedrockMessages, BedrockMessage{ + Role: BedrockMessageRoleUser, + Content: resultBlocks, + }) + stateManager.MarkResultsEmitted(resultIDs) } - if doneItem != nil { - response.Item = doneItem + } + } + + // Helper to flush pending tool call blocks into a single assistant message using state manager + flushPendingToolCalls := func() { + if stateManager.HasPendingToolCalls() { + callIDs := stateManager.EmitPendingToolCalls() + // Create assistant message with tool calls + var contentBlocks []BedrockContentBlock + + // Prepend pending reasoning blocks first (Bedrock requires reasoning before tool_use) + if len(pendingReasoningContentBlocks) > 0 { + contentBlocks = append(contentBlocks, pendingReasoningContentBlocks...) + pendingReasoningContentBlocks = nil } - responses = append(responses, response) - // Emit output_item.done for function call - statusCompleted := "completed" - outputItemDone := &schemas.ResponsesMessage{ - Status: &statusCompleted, + // Add tool use blocks + for _, callID := range callIDs { + if toolCall, exists := stateManager.toolCalls[callID]; exists { + toolUseBlock := &BedrockContentBlock{ + ToolUse: &BedrockToolUse{ + ToolUseID: toolCall.CallID, + Name: toolCall.ToolName, + }, + } + // Parse arguments + var input interface{} + if err := sonic.Unmarshal([]byte(toolCall.Arguments), &input); err != nil { + input = map[string]interface{}{} + } + toolUseBlock.ToolUse.Input = input + contentBlocks = append(contentBlocks, *toolUseBlock) + } } - if itemID != "" { - outputItemDone.ID = &itemID + + if len(contentBlocks) > 0 { + bedrockMessages = append(bedrockMessages, BedrockMessage{ + Role: BedrockMessageRoleAssistant, + Content: contentBlocks, + }) + stateManager.MarkToolCallsEmitted(callIDs, len(bedrockMessages)-1) } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - Item: outputItemDone, - }) } } - // Emit response.completed - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, - Usage: usage, - } - - if state.Model != nil { - response.Model = *state.Model - } + for i, msg := range bifrostMessages { + // Handle nil Type as regular message + msgType := schemas.ResponsesMessageTypeMessage + if msg.Type != nil { + msgType = *msg.Type + } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeCompleted, - SequenceNumber: sequenceNumber + len(responses), - Response: response, - }) + // If we're processing a non-reasoning message and have pending reasoning blocks, + // flush them into the previous assistant message (if it exists) + if msgType != schemas.ResponsesMessageTypeReasoning && len(pendingReasoningContentBlocks) > 0 { + if len(bedrockMessages) > 0 && bedrockMessages[len(bedrockMessages)-1].Role == BedrockMessageRoleAssistant { + // Prepend reasoning blocks to the last assistant message + lastMsg := &bedrockMessages[len(bedrockMessages)-1] + lastMsg.Content = append(pendingReasoningContentBlocks, lastMsg.Content...) + pendingReasoningContentBlocks = nil + } + } - return responses -} + switch msgType { + case schemas.ResponsesMessageTypeFunctionCall: + // Register tool call in state manager + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolName := "" + if msg.ResponsesToolMessage.Name != nil { + toolName = *msg.ResponsesToolMessage.Name + } + arguments := "" + if msg.ResponsesToolMessage.Arguments != nil { + arguments = *msg.ResponsesToolMessage.Arguments + } -// ToBedrockConverseStreamResponse converts a Bifrost Responses stream response to Bedrock streaming format -// Returns a BedrockStreamEvent that represents the streaming event in Bedrock's format -func ToBedrockConverseStreamResponse(bifrostResp *schemas.BifrostResponsesStreamResponse) (*BedrockStreamEvent, error) { - if bifrostResp == nil { - return nil, fmt.Errorf("bifrost stream response is nil") - } + stateManager.RegisterToolCall(*msg.ResponsesToolMessage.CallID, toolName, arguments) + } - event := &BedrockStreamEvent{} + case schemas.ResponsesMessageTypeFunctionCallOutput: + // Register tool result in state manager + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + resultContent := []BedrockContentBlock{} + status := "success" + if msg.Status != nil && *msg.Status != "" { + // Validate status is one of the allowed values + switch *msg.Status { + case "success", "error": + status = *msg.Status + default: + // Default to success for unknown status values + status = "success" + } + } - switch bifrostResp.Type { - case schemas.ResponsesStreamResponseTypeCreated: - // Message start - emit role event - // Always set role for message start event - role := "assistant" - event.Role = &role + // Convert result content to Bedrock format + if msg.ResponsesToolMessage.Output != nil { + if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { + // Try to parse as JSON, otherwise treat as text + var parsed interface{} + if err := json.Unmarshal([]byte(*msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr), &parsed); err != nil { + resultContent = append(resultContent, BedrockContentBlock{ + Text: msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr, + }) + } else { + resultContent = append(resultContent, BedrockContentBlock{ + JSON: parsed, + }) + } + } else if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + // Handle structured output blocks + for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { + switch block.Type { + case schemas.ResponsesOutputMessageContentTypeText: + if block.Text != nil { + resultContent = append(resultContent, BedrockContentBlock{ + Text: block.Text, + }) + } + } + } + } + } - case schemas.ResponsesStreamResponseTypeInProgress: - // In progress - no-op for Bedrock (it doesn't have an explicit in_progress event) - // Return nil to skip this event - return nil, nil + stateManager.RegisterToolResult(*msg.ResponsesToolMessage.CallID, resultContent, status) + } - case schemas.ResponsesStreamResponseTypeOutputItemAdded: - // Content block start - if bifrostResp.Item != nil && bifrostResp.Item.ResponsesToolMessage != nil { - // Tool use start - if bifrostResp.Item.ResponsesToolMessage.Name != nil && bifrostResp.Item.ResponsesToolMessage.CallID != nil { - contentBlockIndex := 0 - if bifrostResp.ContentIndex != nil { - contentBlockIndex = *bifrostResp.ContentIndex + // Check if next message is not a function call output - if so, flush tool calls and results + isLastResultInSequence := true + if i+1 < len(bifrostMessages) { + nextMsg := bifrostMessages[i+1] + nextMsgType := schemas.ResponsesMessageTypeMessage + if nextMsg.Type != nil { + nextMsgType = *nextMsg.Type } - event.ContentBlockIndex = &contentBlockIndex - event.Start = &BedrockContentBlockStart{ - ToolUse: &BedrockToolUseStart{ - ToolUseID: *bifrostResp.Item.ResponsesToolMessage.CallID, - Name: *bifrostResp.Item.ResponsesToolMessage.Name, - }, + if nextMsgType == schemas.ResponsesMessageTypeFunctionCallOutput { + isLastResultInSequence = false } } - } else if bifrostResp.Item != nil { - // Text item added - Bedrock doesn't have an explicit text start event, so we skip it - // Check if it's a text message (has content blocks or is a message type) - if bifrostResp.Item.Content != nil || (bifrostResp.Item.Type != nil && *bifrostResp.Item.Type == schemas.ResponsesMessageTypeMessage) { - return nil, nil - } - } - case schemas.ResponsesStreamResponseTypeOutputTextDelta: - // Text delta - if bifrostResp.Delta != nil && *bifrostResp.Delta != "" { - contentBlockIndex := 0 - if bifrostResp.ContentIndex != nil { - contentBlockIndex = *bifrostResp.ContentIndex - } - event.ContentBlockIndex = &contentBlockIndex - event.Delta = &BedrockContentBlockDelta{ - Text: bifrostResp.Delta, + // If this is the last result in a sequence, flush tool calls and results together + if isLastResultInSequence { + // Emit pending tool calls first + if stateManager.HasPendingToolCalls() { + callIDs := stateManager.EmitPendingToolCalls() + var contentBlocks []BedrockContentBlock + + // Prepend pending reasoning blocks first (Bedrock requires reasoning before tool_use) + if len(pendingReasoningContentBlocks) > 0 { + contentBlocks = append(contentBlocks, pendingReasoningContentBlocks...) + pendingReasoningContentBlocks = nil + } + + // Add tool use blocks + for _, callID := range callIDs { + if toolCall, exists := stateManager.toolCalls[callID]; exists { + toolUseBlock := &BedrockContentBlock{ + ToolUse: &BedrockToolUse{ + ToolUseID: toolCall.CallID, + Name: toolCall.ToolName, + }, + } + var input interface{} + if err := sonic.Unmarshal([]byte(toolCall.Arguments), &input); err != nil { + input = map[string]interface{}{} + } + toolUseBlock.ToolUse.Input = input + contentBlocks = append(contentBlocks, *toolUseBlock) + } + } + + if len(contentBlocks) > 0 { + bedrockMessages = append(bedrockMessages, BedrockMessage{ + Role: BedrockMessageRoleAssistant, + Content: contentBlocks, + }) + stateManager.MarkToolCallsEmitted(callIDs, len(bedrockMessages)-1) + } + } + + // Emit pending results after tool calls + if stateManager.HasPendingResults() { + pendingResults := stateManager.GetPendingResults() + var resultBlocks []BedrockContentBlock + resultIDs := []string{} + for callID, result := range pendingResults { + resultBlocks = append(resultBlocks, BedrockContentBlock{ + ToolResult: &BedrockToolResult{ + ToolUseID: callID, + Content: result.Content, + Status: schemas.Ptr(result.Status), + }, + }) + resultIDs = append(resultIDs, callID) + } + + if len(resultBlocks) > 0 { + bedrockMessages = append(bedrockMessages, BedrockMessage{ + Role: BedrockMessageRoleUser, + Content: resultBlocks, + }) + stateManager.MarkResultsEmitted(resultIDs) + } + } } - } else { - // Skip empty deltas - return nil, nil - } - case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: - // Tool use delta (function call arguments) - if bifrostResp.Delta != nil { - contentBlockIndex := 0 - if bifrostResp.ContentIndex != nil { - contentBlockIndex = *bifrostResp.ContentIndex + case schemas.ResponsesMessageTypeMessage: + // Check if Role is present, skip message if not + if msg.Role == nil { + continue } - event.ContentBlockIndex = &contentBlockIndex - event.Delta = &BedrockContentBlockDelta{ - ToolUse: &BedrockToolUseDelta{ - Input: *bifrostResp.Delta, - }, + + // Extract role from the Responses message structure + role := *msg.Role + + // Always flush pending tool calls and results before processing a new message + // This ensures tool calls and results are properly paired + if stateManager.HasPendingToolCalls() { + callIDs := stateManager.EmitPendingToolCalls() + // Create assistant message with tool calls + var toolUseBlocks []BedrockContentBlock + for _, callID := range callIDs { + if toolCall, exists := stateManager.toolCalls[callID]; exists { + toolUseBlock := &BedrockContentBlock{ + ToolUse: &BedrockToolUse{ + ToolUseID: toolCall.CallID, + Name: toolCall.ToolName, + }, + } + // Parse arguments + var input interface{} + if err := sonic.Unmarshal([]byte(toolCall.Arguments), &input); err != nil { + input = map[string]interface{}{} + } + toolUseBlock.ToolUse.Input = input + toolUseBlocks = append(toolUseBlocks, *toolUseBlock) + } + } + + if len(toolUseBlocks) > 0 { + bedrockMessages = append(bedrockMessages, BedrockMessage{ + Role: BedrockMessageRoleAssistant, + Content: toolUseBlocks, + }) + stateManager.MarkToolCallsEmitted(callIDs, len(bedrockMessages)-1) + } } - } - case schemas.ResponsesStreamResponseTypeOutputTextDone, schemas.ResponsesStreamResponseTypeContentPartDone: - // Content block done - Bedrock doesn't have explicit done events, so we skip them - return nil, nil + // Emit any pending results after tool calls + if stateManager.HasPendingResults() { + pendingResults := stateManager.GetPendingResults() + var resultBlocks []BedrockContentBlock + resultIDs := []string{} + for callID, result := range pendingResults { + resultBlocks = append(resultBlocks, BedrockContentBlock{ + ToolResult: &BedrockToolResult{ + ToolUseID: callID, + Content: result.Content, + Status: schemas.Ptr(result.Status), + }, + }) + resultIDs = append(resultIDs, callID) + } + + if len(resultBlocks) > 0 { + bedrockMessages = append(bedrockMessages, BedrockMessage{ + Role: BedrockMessageRoleUser, + Content: resultBlocks, + }) + stateManager.MarkResultsEmitted(resultIDs) + } + } - case schemas.ResponsesStreamResponseTypeOutputItemDone: - // Item done - Bedrock doesn't have explicit done events, so we skip them - return nil, nil + // Convert regular message + if role == schemas.ResponsesInputMessageRoleSystem { + // Convert to system message + systemMsgs := convertBifrostMessageToBedrockSystemMessages(&msg) + systemMessages = append(systemMessages, systemMsgs...) + } else { + // Convert user/assistant text message + bedrockMsg := convertBifrostMessageToBedrockMessage(&msg) + if bedrockMsg != nil { + bedrockMessages = append(bedrockMessages, *bedrockMsg) + } + } - case schemas.ResponsesStreamResponseTypeCompleted: - // Message stop - always set stopReason - stopReason := "end_turn" - if bifrostResp.Response != nil && bifrostResp.Response.IncompleteDetails != nil { - stopReason = bifrostResp.Response.IncompleteDetails.Reason + case schemas.ResponsesMessageTypeReasoning: + // Handle reasoning as content in next assistant message + // For now, just add to pending content blocks + reasoningBlocks := convertBifrostReasoningToBedrockReasoning(&msg) + if len(reasoningBlocks) > 0 { + pendingReasoningContentBlocks = append(pendingReasoningContentBlocks, reasoningBlocks...) + } } - event.StopReason = &stopReason + } - // Add usage if available - if bifrostResp.Response != nil && bifrostResp.Response.Usage != nil { - event.Usage = &BedrockTokenUsage{ - InputTokens: bifrostResp.Response.Usage.InputTokens, - OutputTokens: bifrostResp.Response.Usage.OutputTokens, - TotalTokens: bifrostResp.Response.Usage.TotalTokens, - } + // Flush any remaining pending tool calls + flushPendingToolCalls() + + // Flush any remaining pending tool results + flushPendingToolResults() + + // For Bedrock compatibility, reasoning blocks must not be the final block in an assistant message + // If we have pending reasoning blocks and the last message is an assistant message, + // merge them into a single message with reasoning first + if len(pendingReasoningContentBlocks) > 0 { + if len(bedrockMessages) > 0 && bedrockMessages[len(bedrockMessages)-1].Role == BedrockMessageRoleAssistant { + // Last message is an assistant message - prepend reasoning blocks to it + lastMsg := &bedrockMessages[len(bedrockMessages)-1] + lastMsg.Content = append(pendingReasoningContentBlocks, lastMsg.Content...) + pendingReasoningContentBlocks = nil } + // If no assistant message to merge into, discard the reasoning blocks + // (they cannot exist alone in Bedrock without violating the constraint) + } - case schemas.ResponsesStreamResponseTypeError: - // Error - errors are handled separately by the router via BifrostError in the stream chunk - // Return nil to skip this chunk - return nil, nil + return bedrockMessages, systemMessages, nil +} - default: - // Unknown type - skip - return nil, nil +// ConvertBedrockMessagesToBifrostMessages converts an array of Bedrock messages to Bifrost ResponsesMessage format +// This is the main conversion method from Bedrock to Bifrost - handles all message types and content blocks +func ConvertBedrockMessagesToBifrostMessages(bedrockMessages []BedrockMessage, systemMessages []BedrockSystemMessage, isOutputMessage bool) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage + + // Convert system messages first + for _, sysMsg := range systemMessages { + systemBifrostMsgs := convertBedrockSystemMessageToBifrostMessages(&sysMsg) + bifrostMessages = append(bifrostMessages, systemBifrostMsgs...) } - return event, nil -} + // Convert regular messages + for _, msg := range bedrockMessages { + convertedMessages := convertSingleBedrockMessageToBifrostMessages(&msg, isOutputMessage) + bifrostMessages = append(bifrostMessages, convertedMessages...) + } -// BedrockEncodedEvent represents a single event ready for encoding to AWS Event Stream -type BedrockEncodedEvent struct { - EventType string - Payload interface{} + return bifrostMessages } -// BedrockInvokeStreamChunkEvent represents the chunk event for invoke-with-response-stream -type BedrockInvokeStreamChunkEvent struct { - Bytes []byte `json:"bytes"` +// Helper functions for converting individual Bedrock message types + +// convertBifrostMessageToBedrockSystemMessages converts a Bifrost system message to Bedrock system messages +func convertBifrostMessageToBedrockSystemMessages(msg *schemas.ResponsesMessage) []BedrockSystemMessage { + var systemMessages []BedrockSystemMessage + + if msg.Content != nil { + if msg.Content.ContentStr != nil { + systemMessages = append(systemMessages, BedrockSystemMessage{ + Text: msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + systemMessages = append(systemMessages, BedrockSystemMessage{ + Text: block.Text, + }) + } + } + } + } + + return systemMessages } -// ToEncodedEvents converts the flat BedrockStreamEvent into a sequence of specific events -func (event *BedrockStreamEvent) ToEncodedEvents() []BedrockEncodedEvent { - var events []BedrockEncodedEvent +// convertBifrostMessageToBedrockMessage converts a regular Bifrost message to Bedrock message +func convertBifrostMessageToBedrockMessage(msg *schemas.ResponsesMessage) *BedrockMessage { + // Ensure Content is present + if msg.Content == nil { + return nil + } - if event.InvokeModelRawChunk != nil { - events = append(events, BedrockEncodedEvent{ - EventType: "chunk", - Payload: BedrockInvokeStreamChunkEvent{ - Bytes: event.InvokeModelRawChunk, - }, - }) + bedrockMsg := BedrockMessage{ + Role: BedrockMessageRole(*msg.Role), } - if event.Role != nil { - events = append(events, BedrockEncodedEvent{ - EventType: "messageStart", - Payload: BedrockMessageStartEvent{ - Role: *event.Role, - }, - }) + // Convert content + contentBlocks, err := convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(*msg.Content) + if err != nil { + return nil } + bedrockMsg.Content = contentBlocks - if event.Start != nil { - events = append(events, BedrockEncodedEvent{ - EventType: "contentBlockStart", - Payload: struct { - Start *BedrockContentBlockStart `json:"start"` - ContentBlockIndex *int `json:"contentBlockIndex"` - }{ - Start: event.Start, - ContentBlockIndex: event.ContentBlockIndex, + return &bedrockMsg +} + +// convertBedrockSystemMessageToBifrostMessages converts a Bedrock system message to Bifrost messages +func convertBedrockSystemMessageToBifrostMessages(sysMsg *BedrockSystemMessage) []schemas.ResponsesMessage { + if sysMsg.Text != nil { + systemRole := schemas.ResponsesInputMessageRoleSystem + msgType := schemas.ResponsesMessageTypeMessage + return []schemas.ResponsesMessage{{ + Type: &msgType, + Role: &systemRole, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: sysMsg.Text, + }, + }, }, - }) + }} } + return []schemas.ResponsesMessage{} +} - if event.Delta != nil { - events = append(events, BedrockEncodedEvent{ - EventType: "contentBlockDelta", - Payload: struct { - Delta *BedrockContentBlockDelta `json:"delta"` - ContentBlockIndex *int `json:"contentBlockIndex"` - }{ - Delta: event.Delta, - ContentBlockIndex: event.ContentBlockIndex, - }, - }) +// convertSingleBedrockMessageToBifrostMessages converts a single Bedrock message to Bifrost messages +func convertSingleBedrockMessageToBifrostMessages(msg *BedrockMessage, isOutputMessage bool) []schemas.ResponsesMessage { + var outputMessages []schemas.ResponsesMessage + var reasoningContentBlocks []schemas.ResponsesMessageContentBlock + + for _, block := range msg.Content { + if block.Text != nil { + // Text content + var role schemas.ResponsesMessageRoleType + switch msg.Role { + case BedrockMessageRoleUser: + role = schemas.ResponsesInputMessageRoleUser + case BedrockMessageRoleAssistant: + role = schemas.ResponsesInputMessageRoleAssistant + default: + role = schemas.ResponsesInputMessageRoleUser + } + + // For assistant messages (previous model outputs), use output_text type + // For user/system messages, use input_text type + textBlockType := schemas.ResponsesInputMessageContentBlockTypeText + if isOutputMessage || msg.Role == BedrockMessageRoleAssistant { + textBlockType = schemas.ResponsesOutputMessageContentTypeText + } + + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: textBlockType, + Text: block.Text, + }, + }, + }, + } + if isOutputMessage { + bifrostMsg.ID = schemas.Ptr("msg_" + fmt.Sprintf("%d", time.Now().UnixNano())) + } + outputMessages = append(outputMessages, bifrostMsg) + + } else if block.ReasoningContent != nil { + // Reasoning content - collect to create a single reasoning message + if block.ReasoningContent.ReasoningText != nil { + reasoningContentBlocks = append(reasoningContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: &block.ReasoningContent.ReasoningText.Text, + Signature: block.ReasoningContent.ReasoningText.Signature, + }) + } + + } else if block.ToolUse != nil { + // Tool use content + // Create copies of the values to avoid range loop variable capture + toolUseID := block.ToolUse.ToolUseID + toolUseName := block.ToolUse.Name + + toolMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolUseID, + Name: &toolUseName, + Arguments: schemas.Ptr(schemas.JsonifyInput(block.ToolUse.Input)), + }, + } + if isOutputMessage { + toolMsg.ID = schemas.Ptr("msg_" + fmt.Sprintf("%d", time.Now().UnixNano())) + role := schemas.ResponsesInputMessageRoleAssistant + toolMsg.Role = &role + } + outputMessages = append(outputMessages, toolMsg) + + } else if block.ToolResult != nil { + // Tool result content - typically not in assistant output but handled for completeness + // Prefer JSON payloads without unmarshalling; fallback to text + var resultContent string + if len(block.ToolResult.Content) > 0 { + // JSON first (no unmarshal; just one marshal to string when present) + for _, c := range block.ToolResult.Content { + if c.JSON != nil { + resultContent = schemas.JsonifyInput(c.JSON) + break + } + } + // Fallback to first available text block + if resultContent == "" { + for _, c := range block.ToolResult.Content { + if c.Text != nil { + resultContent = *c.Text + break + } + } + } + } + + // Create a copy of the value to avoid range loop variable capture + toolResultID := block.ToolResult.ToolUseID + + resultMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolResultID, + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: &resultContent, + }, + }, + } + if isOutputMessage { + resultMsg.ID = schemas.Ptr("msg_" + fmt.Sprintf("%d", time.Now().UnixNano())) + role := schemas.ResponsesInputMessageRoleAssistant + resultMsg.Role = &role + resultMsg.Content = &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &resultContent, + }, + }, + } + } + outputMessages = append(outputMessages, resultMsg) + } } - if event.StopReason != nil { - events = append(events, BedrockEncodedEvent{ - EventType: "messageStop", - Payload: BedrockMessageStopEvent{ - StopReason: *event.StopReason, + // Handle reasoning blocks - prepend reasoning message if we collected any + if len(reasoningContentBlocks) > 0 { + reasoningMessage := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + fmt.Sprintf("%d", time.Now().UnixNano())), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, }, - }) + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: reasoningContentBlocks, + }, + } + // Prepend the reasoning message to the start of the messages list + outputMessages = append([]schemas.ResponsesMessage{reasoningMessage}, outputMessages...) } - if event.Usage != nil || event.Metrics != nil { - events = append(events, BedrockEncodedEvent{ - EventType: "metadata", - Payload: BedrockMetadataEvent{ - Usage: event.Usage, - Metrics: event.Metrics, - Trace: event.Trace, - }, + return outputMessages +} + +// convertBifrostReasoningToBedrockReasoning converts a Bifrost reasoning message to Bedrock reasoning blocks +func convertBifrostReasoningToBedrockReasoning(msg *schemas.ResponsesMessage) []BedrockContentBlock { + var reasoningBlocks []BedrockContentBlock + + if msg.Content != nil && msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Type == schemas.ResponsesOutputMessageContentTypeReasoning && block.Text != nil { + reasoningBlock := BedrockContentBlock{ + ReasoningContent: &BedrockReasoningContent{ + ReasoningText: &BedrockReasoningContentText{ + Text: *block.Text, + Signature: block.Signature, + }, + }, + } + reasoningBlocks = append(reasoningBlocks, reasoningBlock) + } + } + } else if msg.ResponsesReasoning != nil { + if msg.ResponsesReasoning.Summary != nil { + for _, reasoningContent := range msg.ResponsesReasoning.Summary { + reasoningBlock := BedrockContentBlock{ + ReasoningContent: &BedrockReasoningContent{ + ReasoningText: &BedrockReasoningContentText{ + Text: reasoningContent.Text, + }, + }, + } + reasoningBlocks = append(reasoningBlocks, reasoningBlock) + } + } else if msg.ResponsesReasoning.EncryptedContent != nil { + // Bedrock doesn't have a direct equivalent to encrypted content, + // so we'll store it as a regular reasoning block with a special marker + encryptedText := fmt.Sprintf("[ENCRYPTED_REASONING: %s]", *msg.ResponsesReasoning.EncryptedContent) + reasoningBlock := BedrockContentBlock{ + ReasoningContent: &BedrockReasoningContent{ + ReasoningText: &BedrockReasoningContentText{ + Text: encryptedText, + }, + }, + } + reasoningBlocks = append(reasoningBlocks, reasoningBlock) + } + } + + return reasoningBlocks +} + +// convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks converts Bifrost content to Bedrock content blocks +func convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(content schemas.ResponsesMessageContent) ([]BedrockContentBlock, error) { + var blocks []BedrockContentBlock + + if content.ContentStr != nil { + blocks = append(blocks, BedrockContentBlock{ + Text: content.ContentStr, }) + } else if content.ContentBlocks != nil { + for _, block := range content.ContentBlocks { + + bedrockBlock := BedrockContentBlock{} + + switch block.Type { + case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: + bedrockBlock.Text = block.Text + case schemas.ResponsesInputMessageContentBlockTypeImage: + if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + imageSource, err := convertImageToBedrockSource(*block.ResponsesInputMessageContentBlockImage.ImageURL) + if err != nil { + return nil, fmt.Errorf("failed to convert image in responses content block: %w", err) + } + bedrockBlock.Image = imageSource + } + case schemas.ResponsesOutputMessageContentTypeReasoning: + if block.Text != nil { + bedrockBlock.ReasoningContent = &BedrockReasoningContent{ + ReasoningText: &BedrockReasoningContentText{ + Text: *block.Text, + Signature: block.Signature, + }, + } + } + default: + // Don't add anything for unknown types + continue + } + + blocks = append(blocks, bedrockBlock) + } } - return events + return blocks, nil } diff --git a/core/providers/bedrock/types.go b/core/providers/bedrock/types.go index faa224e1f..ca21116e2 100644 --- a/core/providers/bedrock/types.go +++ b/core/providers/bedrock/types.go @@ -1,9 +1,16 @@ package bedrock -import "github.com/maximhq/bifrost/core/schemas" +import ( + "encoding/json" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) // DefaultBedrockRegion is the default region for Bedrock const DefaultBedrockRegion = "us-east-1" +const MinimumReasoningMaxTokens = 1 +const DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default // ==================== REQUEST TYPES ==================== @@ -59,6 +66,12 @@ type BedrockConverseRequest struct { RequestMetadata map[string]string `json:"requestMetadata,omitempty"` // Request metadata ServiceTier *BedrockServiceTier `json:"serviceTier,omitempty"` // Service tier configuration (note: camelCase in both request and response) Stream bool `json:"-"` // Whether streaming is requested (internal, not in JSON) + + // Extra params for advanced use cases + ExtraParams map[string]interface{} `json:"extra_params,omitempty"` + + // Bifrost specific field (only parsed when converting from Provider -> Bifrost request) + Fallbacks []string `json:"fallbacks,omitempty"` } // IsStreamingRequested implements the StreamingRequest interface @@ -66,6 +79,66 @@ func (r *BedrockConverseRequest) IsStreamingRequested() bool { return r.Stream } +// Known fields for BedrockConverseRequest +var bedrockConverseRequestKnownFields = map[string]bool{ + "messages": true, + "system": true, + "inferenceConfig": true, + "toolConfig": true, + "guardrailConfig": true, + "additionalModelRequestFields": true, + "additionalModelResponseFieldPaths": true, + "performanceConfig": true, + "promptVariables": true, + "requestMetadata": true, + "serviceTier": true, + "stream": true, + "extra_params": true, + "fallbacks": true, +} + +// UnmarshalJSON implements custom JSON unmarshalling for BedrockConverseRequest. +// This captures all unregistered fields into ExtraParams. +func (r *BedrockConverseRequest) UnmarshalJSON(data []byte) error { + // Create an alias type to avoid infinite recursion + type Alias BedrockConverseRequest + + // First, unmarshal into the alias to populate all known fields + aux := &struct { + *Alias + }{ + Alias: (*Alias)(r), + } + + if err := sonic.Unmarshal(data, aux); err != nil { + return err + } + + // Parse JSON to extract unknown fields + var rawData map[string]json.RawMessage + if err := sonic.Unmarshal(data, &rawData); err != nil { + return err + } + + // Initialize ExtraParams if not already initialized + if r.ExtraParams == nil { + r.ExtraParams = make(map[string]interface{}) + } + + // Extract unknown fields + for key, value := range rawData { + if !bedrockConverseRequestKnownFields[key] { + var v interface{} + if err := sonic.Unmarshal(value, &v); err != nil { + continue // Skip fields that can't be unmarshaled + } + r.ExtraParams[key] = v + } + } + + return nil +} + type BedrockMessageRole string const ( diff --git a/core/providers/bedrock/utils.go b/core/providers/bedrock/utils.go index fd7fe92f9..cdec5e367 100644 --- a/core/providers/bedrock/utils.go +++ b/core/providers/bedrock/utils.go @@ -8,11 +8,13 @@ import ( "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/anthropic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) // convertParameters handles parameter conversion func convertChatParameters(ctx *context.Context, bifrostReq *schemas.BifrostChatRequest, bedrockReq *BedrockConverseRequest) error { + // Parameters are optional - if not provided, just skip conversion if bifrostReq.Params == nil { return nil } @@ -34,20 +36,43 @@ func convertChatParameters(ctx *context.Context, bifrostReq *schemas.BifrostChat if bedrockReq.AdditionalModelRequestFields == nil { bedrockReq.AdditionalModelRequestFields = make(schemas.OrderedMap) } - if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort == "none" { - bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]string{ - "type": "disabled", - } - } else { - if bifrostReq.Params.Reasoning.MaxTokens == nil { - return fmt.Errorf("reasoning.max_tokens is required for reasoning") - } else if schemas.IsAnthropicModel(bedrockReq.ModelID) && *bifrostReq.Params.Reasoning.MaxTokens < anthropic.MinimumReasoningMaxTokens { - return fmt.Errorf("reasoning.max_tokens must be greater than or equal to %d", anthropic.MinimumReasoningMaxTokens) + if bifrostReq.Params.Reasoning.MaxTokens != nil { + if schemas.IsAnthropicModel(bifrostReq.Model) && *bifrostReq.Params.Reasoning.MaxTokens < anthropic.MinimumReasoningMaxTokens { + return fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", anthropic.MinimumReasoningMaxTokens) } bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]any{ "type": "enabled", "budget_tokens": *bifrostReq.Params.Reasoning.MaxTokens, } + } else if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort != "none" { + maxTokens := DefaultCompletionMaxTokens + if bedrockReq.InferenceConfig != nil && bedrockReq.InferenceConfig.MaxTokens != nil { + maxTokens = *bedrockReq.InferenceConfig.MaxTokens + } else { + if bedrockReq.InferenceConfig != nil { + bedrockReq.InferenceConfig.MaxTokens = schemas.Ptr(DefaultCompletionMaxTokens) + } else { + bedrockReq.InferenceConfig = &BedrockInferenceConfig{ + MaxTokens: schemas.Ptr(DefaultCompletionMaxTokens), + } + } + } + minBudgetTokens := MinimumReasoningMaxTokens + if schemas.IsAnthropicModel(bifrostReq.Model) { + minBudgetTokens = anthropic.MinimumReasoningMaxTokens + } + budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, minBudgetTokens, maxTokens) + if err != nil { + return err + } + bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]any{ + "type": "enabled", + "budget_tokens": budgetTokens, + } + } else { + bedrockReq.AdditionalModelRequestFields["reasoning_config"] = map[string]string{ + "type": "disabled", + } } } diff --git a/core/providers/cerebras/cerebras_test.go b/core/providers/cerebras/cerebras_test.go index 174f06a71..188316a00 100644 --- a/core/providers/cerebras/cerebras_test.go +++ b/core/providers/cerebras/cerebras_test.go @@ -31,6 +31,7 @@ func TestCerebras(t *testing.T) { }, TextModel: "llama3.1-8b", EmbeddingModel: "", // Cerebras doesn't support embedding + ReasoningModel: "gpt-oss-120b", Scenarios: testutil.TestScenarios{ TextCompletion: true, TextCompletionStream: true, @@ -48,6 +49,7 @@ func TestCerebras(t *testing.T) { CompleteEnd2End: true, Embedding: false, ListModels: true, + Reasoning: true, }, } diff --git a/core/providers/cohere/chat.go b/core/providers/cohere/chat.go index 052a1fa1d..1d2391b10 100644 --- a/core/providers/cohere/chat.go +++ b/core/providers/cohere/chat.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -100,18 +101,30 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*Coh cohereReq.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty + // Convert reasoning if bifrostReq.Params.Reasoning != nil { - if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort == "none" { + if bifrostReq.Params.Reasoning.MaxTokens != nil { cohereReq.Thinking = &CohereThinking{ - Type: ThinkingTypeDisabled, + Type: ThinkingTypeEnabled, + TokenBudget: bifrostReq.Params.Reasoning.MaxTokens, } - } else { - if bifrostReq.Params.Reasoning.MaxTokens == nil { - return nil, fmt.Errorf("reasoning.max_tokens is required for reasoning") - } else { + } else if bifrostReq.Params.Reasoning.Effort != nil { + if *bifrostReq.Params.Reasoning.Effort != "none" { + maxCompletionTokens := DefaultCompletionMaxTokens + if bifrostReq.Params.MaxCompletionTokens != nil { + maxCompletionTokens = *bifrostReq.Params.MaxCompletionTokens + } + budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxCompletionTokens) + if err != nil { + return nil, err + } cohereReq.Thinking = &CohereThinking{ Type: ThinkingTypeEnabled, - TokenBudget: bifrostReq.Params.Reasoning.MaxTokens, + TokenBudget: schemas.Ptr(budgetTokens), // Max tokens for reasoning + } + } else { + cohereReq.Thinking = &CohereThinking{ + Type: ThinkingTypeDisabled, } } } @@ -426,6 +439,7 @@ func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.Bifros return streamResponse, nil, false } else if chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != nil { + thinkingText := *chunk.Delta.Message.Content.CohereStreamContentObject.Thinking streamResponse := &schemas.BifrostChatResponse{ Object: "chat.completion.chunk", Choices: []schemas.BifrostResponseChoice{ @@ -433,12 +447,12 @@ func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.Bifros Index: 0, ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ Delta: &schemas.ChatStreamResponseChoiceDelta{ - Reasoning: chunk.Delta.Message.Content.CohereStreamContentObject.Thinking, + Reasoning: schemas.Ptr(thinkingText), ReasoningDetails: []schemas.ChatReasoningDetails{ { Index: 0, Type: schemas.BifrostReasoningDetailsTypeText, - Text: chunk.Delta.Message.Content.CohereStreamContentObject.Thinking, + Text: schemas.Ptr(thinkingText), }, }, }, diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index 703710d85..f5171971a 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -443,8 +443,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo continue } - chunkIndex++ - // Extract response ID from message-start events if event.Type == StreamEventMessageStart && event.ID != nil { responseID = *event.ID @@ -512,7 +510,7 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return ToCohereResponsesRequest(request), nil }, + func() (any, error) { return ToCohereResponsesRequest(request) }, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr @@ -569,7 +567,10 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun ctx, request, func() (any, error) { - reqBody := ToCohereResponsesRequest(request) + reqBody, err := ToCohereResponsesRequest(request) + if err != nil { + return nil, err + } if reqBody != nil { reqBody.Stream = schemas.Ptr(true) } diff --git a/core/providers/cohere/cohere_test.go b/core/providers/cohere/cohere_test.go index aa3fce000..ca79483c3 100644 --- a/core/providers/cohere/cohere_test.go +++ b/core/providers/cohere/cohere_test.go @@ -28,6 +28,7 @@ func TestCohere(t *testing.T) { VisionModel: "command-a-vision-07-2025", // Cohere's latest vision model TextModel: "", // Cohere focuses on chat EmbeddingModel: "embed-v4.0", + ReasoningModel: "command-a-reasoning-08-2025", Scenarios: testutil.TestScenarios{ TextCompletion: false, // Not typical for Cohere SimpleChat: true, diff --git a/core/providers/cohere/responses.go b/core/providers/cohere/responses.go index d3907310d..77d080165 100644 --- a/core/providers/cohere/responses.go +++ b/core/providers/cohere/responses.go @@ -6,35 +6,40 @@ import ( "sync" "time" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) // CohereResponsesStreamState tracks state during streaming conversion for responses API type CohereResponsesStreamState struct { - ContentIndexToOutputIndex map[int]int // Maps Cohere content_index to OpenAI output_index - ToolArgumentBuffers map[int]string // Maps output_index to accumulated tool argument JSON - ItemIDs map[int]string // Maps output_index to item ID for stable IDs - CurrentOutputIndex int // Current output index counter - MessageID *string // Message ID from message_start - Model *string // Model name from message_start - CreatedAt int // Timestamp for created_at consistency - HasEmittedCreated bool // Whether we've emitted response.created - HasEmittedInProgress bool // Whether we've emitted response.in_progress - ToolPlanOutputIndex *int // Output index for tool plan text item (if created) + ContentIndexToOutputIndex map[int]int // Maps Cohere content_index to OpenAI output_index + ToolArgumentBuffers map[int]string // Maps output_index to accumulated tool argument JSON + ItemIDs map[int]string // Maps output_index to item ID for stable IDs + ReasoningContentIndices map[int]bool // Tracks which content indices are reasoning blocks + AnnotationIndexToContentIndex map[int]int // Maps annotation index to content index for citation pairing + CurrentOutputIndex int // Current output index counter + MessageID *string // Message ID from message_start + Model *string // Model name from message_start + CreatedAt int // Timestamp for created_at consistency + HasEmittedCreated bool // Whether we've emitted response.created + HasEmittedInProgress bool // Whether we've emitted response.in_progress + ToolPlanOutputIndex *int // Output index for tool plan text item (if created) } // cohereResponsesStreamStatePool provides a pool for Cohere responses stream state objects. var cohereResponsesStreamStatePool = sync.Pool{ New: func() interface{} { return &CohereResponsesStreamState{ - ContentIndexToOutputIndex: make(map[int]int), - ToolArgumentBuffers: make(map[int]string), - ItemIDs: make(map[int]string), - CurrentOutputIndex: 0, - CreatedAt: int(time.Now().Unix()), - HasEmittedCreated: false, - HasEmittedInProgress: false, - ToolPlanOutputIndex: nil, + ContentIndexToOutputIndex: make(map[int]int), + ToolArgumentBuffers: make(map[int]string), + ItemIDs: make(map[int]string), + ReasoningContentIndices: make(map[int]bool), + AnnotationIndexToContentIndex: make(map[int]int), + CurrentOutputIndex: 0, + CreatedAt: int(time.Now().Unix()), + HasEmittedCreated: false, + HasEmittedInProgress: false, + ToolPlanOutputIndex: nil, } }, } @@ -59,6 +64,16 @@ func acquireCohereResponsesStreamState() *CohereResponsesStreamState { } else { clear(state.ItemIDs) } + if state.ReasoningContentIndices == nil { + state.ReasoningContentIndices = make(map[int]bool) + } else { + clear(state.ReasoningContentIndices) + } + if state.AnnotationIndexToContentIndex == nil { + state.AnnotationIndexToContentIndex = make(map[int]int) + } else { + clear(state.AnnotationIndexToContentIndex) + } // Reset other fields state.CurrentOutputIndex = 0 state.MessageID = nil @@ -96,6 +111,16 @@ func (state *CohereResponsesStreamState) flush() { } else { clear(state.ItemIDs) } + if state.ReasoningContentIndices == nil { + state.ReasoningContentIndices = make(map[int]bool) + } else { + clear(state.ReasoningContentIndices) + } + if state.AnnotationIndexToContentIndex == nil { + state.AnnotationIndexToContentIndex = make(map[int]int) + } else { + clear(state.AnnotationIndexToContentIndex) + } state.CurrentOutputIndex = 0 state.MessageID = nil state.Model = nil @@ -125,494 +150,433 @@ func (state *CohereResponsesStreamState) getOrCreateOutputIndex(contentIndex *in return outputIndex } -// ToCohereResponsesRequest converts a BifrostRequest (Responses structure) to CohereChatRequest -func ToCohereResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *CohereChatRequest { - if bifrostReq == nil { - return nil - } - - cohereReq := &CohereChatRequest{ - Model: bifrostReq.Model, - } - - // Map basic parameters - if bifrostReq.Params != nil { - if bifrostReq.Params.MaxOutputTokens != nil { - cohereReq.MaxTokens = bifrostReq.Params.MaxOutputTokens +// convertCohereContentBlockToBifrost converts CohereContentBlock to schemas.ContentBlock for Responses +func convertCohereContentBlockToBifrost(cohereBlock CohereContentBlock) schemas.ResponsesMessageContentBlock { + switch cohereBlock.Type { + case CohereContentBlockTypeText: + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: cohereBlock.Text, } - if bifrostReq.Params.Temperature != nil { - cohereReq.Temperature = bifrostReq.Params.Temperature + case CohereContentBlockTypeImage: + // For images, create a text block describing the image (should never happen) + if cohereBlock.ImageURL == nil { + // Skip invalid image blocks without ImageURL + return schemas.ResponsesMessageContentBlock{} } - if bifrostReq.Params.TopP != nil { - cohereReq.P = bifrostReq.Params.TopP + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeImage, + ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: &cohereBlock.ImageURL.URL, + }, } - - // Convert response_format from Text.Format to Cohere format - if bifrostReq.Params.Text != nil && bifrostReq.Params.Text.Format != nil { - cohereReq.ResponseFormat = convertResponsesTextFormatToCohere(bifrostReq.Params.Text.Format) + case CohereContentBlockTypeThinking: + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: cohereBlock.Thinking, } + default: + // Fallback to text block + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: schemas.Ptr(string(cohereBlock.Type)), + } + } +} - if bifrostReq.Params.ExtraParams != nil { - if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { - cohereReq.K = topK - } - if stop, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["stop"]); ok { - cohereReq.StopSequences = stop - } - if frequencyPenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["frequency_penalty"]); ok { - cohereReq.FrequencyPenalty = frequencyPenalty - } - if presencePenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["presence_penalty"]); ok { - cohereReq.PresencePenalty = presencePenalty +func (chunk *CohereStreamEvent) ToBifrostResponsesStream(sequenceNumber int, state *CohereResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError, bool) { + switch chunk.Type { + case StreamEventMessageStart: + // Message start - emit response.created and response.in_progress (OpenAI-style lifecycle) + if chunk.ID != nil { + state.MessageID = chunk.ID + // Use the state's CreatedAt for consistency + if state.CreatedAt == 0 { + state.CreatedAt = int(time.Now().Unix()) } - if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok { - if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok { - thinking := &CohereThinking{} - if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok { - thinking.Type = CohereThinkingType(typeStr) - } - if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok { - thinking.TokenBudget = tokenBudget - } - cohereReq.Thinking = thinking + + var responses []*schemas.BifrostResponsesStreamResponse + + // Emit response.created + if !state.HasEmittedCreated { + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCreated, + SequenceNumber: sequenceNumber, + Response: response, + }) + state.HasEmittedCreated = true } - } - } - // Convert tools - if bifrostReq.Params != nil && bifrostReq.Params.Tools != nil { - var cohereTools []CohereChatRequestTool - for _, tool := range bifrostReq.Params.Tools { - if tool.ResponsesToolFunction != nil && tool.Name != nil { - cohereTool := CohereChatRequestTool{ - Type: "function", - Function: CohereChatRequestFunction{ - Name: *tool.Name, - Description: tool.Description, - Parameters: tool.ResponsesToolFunction.Parameters, - }, + // Emit response.in_progress + if !state.HasEmittedInProgress { + response := &schemas.BifrostResponsesResponse{ + ID: state.MessageID, + CreatedAt: state.CreatedAt, // Use same timestamp } - cohereTools = append(cohereTools, cohereTool) + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeInProgress, + SequenceNumber: sequenceNumber + len(responses), + Response: response, + }) + state.HasEmittedInProgress = true } - } - if len(cohereTools) > 0 { - cohereReq.Tools = cohereTools + if len(responses) > 0 { + return responses, nil, false + } } - } - - // Convert tool choice - if bifrostReq.Params != nil && bifrostReq.Params.ToolChoice != nil { - cohereReq.ToolChoice = convertBifrostToolChoiceToCohereToolChoice(*bifrostReq.Params.ToolChoice) - } - - // Process ResponsesInput (which contains the Responses items) - if bifrostReq.Input != nil { - cohereReq.Messages = convertResponsesMessagesToCohereMessages(bifrostReq.Input) - } - - return cohereReq -} - -// ToBifrostResponsesResponse converts CohereChatResponse to BifrostResponse (Responses structure) -func (response *CohereChatResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse { - if response == nil { - return nil - } + case StreamEventContentStart: + // Content block start - emit output_item.added (OpenAI-style) + // First, close tool plan message item if it's still open + var responses []*schemas.BifrostResponsesStreamResponse + if state.ToolPlanOutputIndex != nil { + outputIndex := *state.ToolPlanOutputIndex + itemID := state.ItemIDs[outputIndex] - bifrostResp := &schemas.BifrostResponsesResponse{ - ID: schemas.Ptr(response.ID), - CreatedAt: int(time.Now().Unix()), // Set current timestamp - } + // Emit output_text.done (without accumulated text, just the event) + emptyText := "" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(0), + ItemID: &itemID, + Text: &emptyText, + }) - // Convert usage information - if response.Usage != nil { - usage := &schemas.ResponsesResponseUsage{} + // Emit content_part.done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(0), + ItemID: &itemID, + }) - if response.Usage.Tokens != nil { - if response.Usage.Tokens.InputTokens != nil { - usage.InputTokens = *response.Usage.Tokens.InputTokens - } - if response.Usage.Tokens.OutputTokens != nil { - usage.OutputTokens = *response.Usage.Tokens.OutputTokens + // Emit output_item.done + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, } - usage.TotalTokens = usage.InputTokens + usage.OutputTokens - } - - if response.Usage.CachedTokens != nil { - usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ - CachedTokens: *response.Usage.CachedTokens, + if itemID != "" { + doneItem.ID = &itemID } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(0), + Item: doneItem, + }) + state.ToolPlanOutputIndex = nil // Mark as closed } - bifrostResp.Usage = usage - } + if chunk.Delta != nil && chunk.Index != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Content != nil && chunk.Delta.Message.Content.CohereStreamContentObject != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) - // Convert output message to Responses format - if response.Message != nil { - outputMessages := convertCohereMessageToResponsesOutput(*response.Message) - bifrostResp.Output = outputMessages - } + switch chunk.Delta.Message.Content.CohereStreamContentObject.Type { + case CohereContentBlockTypeText: + // Text block - emit output_item.added with type "message" + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant - return bifrostResp -} + // Generate stable ID for text item + var itemID string + if state.MessageID == nil { + itemID = fmt.Sprintf("item_%d", outputIndex) + } else { + itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) + } + state.ItemIDs[outputIndex] = itemID -// Helper functions + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // Empty blocks slice for mutation support + }, + } -// convertBifrostToolChoiceToCohere converts schemas.ToolChoice to CohereToolChoice -func convertBifrostToolChoiceToCohereToolChoice(toolChoice schemas.ResponsesToolChoice) *CohereToolChoice { - toolChoiceString := toolChoice.ResponsesToolChoiceStr + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }) - if toolChoiceString != nil { - switch *toolChoiceString { - case "none": - choice := ToolChoiceNone - return &choice - case "required", "auto", "function": - choice := ToolChoiceRequired - return &choice - default: - choice := ToolChoiceRequired - return &choice - } - } + // Emit content_part.added with empty output_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &emptyText, + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + ItemID: &itemID, + Part: part, + }) + return responses, nil, false + case CohereContentBlockTypeThinking: + // Thinking/reasoning content - emit as reasoning item + messageType := schemas.ResponsesMessageTypeReasoning + role := schemas.ResponsesInputMessageRoleAssistant - return nil -} + // Generate stable ID for reasoning item + itemID := "rs_" + providerUtils.GetRandomString(50) + state.ItemIDs[outputIndex] = itemID -// convertResponsesMessagesToCohereMessages converts Responses items to Cohere messages -func convertResponsesMessagesToCohereMessages(messages []schemas.ResponsesMessage) []CohereMessage { - var cohereMessages []CohereMessage - var systemContent []string + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + }, + } - for _, msg := range messages { - // Handle nil Type with default - msgType := schemas.ResponsesMessageTypeMessage - if msg.Type != nil { - msgType = *msg.Type - } + // Track that this content index is a reasoning block + if chunk.Index != nil { + state.ReasoningContentIndices[*chunk.Index] = true + } - switch msgType { - case schemas.ResponsesMessageTypeMessage: - // Handle nil Role with default - role := "user" - if msg.Role != nil { - role = string(*msg.Role) + // Emit output_item.added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }) + + // Emit content_part.added with empty reasoning_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: &emptyText, + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + ItemID: &itemID, + Part: part, + }) + + return responses, nil, false } + } + if len(responses) > 0 { + return responses, nil, false + } + case StreamEventContentDelta: + if chunk.Index != nil && chunk.Delta != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) - if role == "system" { - // Collect system messages separately for Cohere - if msg.Content != nil { - if msg.Content.ContentStr != nil { - systemContent = append(systemContent, *msg.Content.ContentStr) - } else if msg.Content.ContentBlocks != nil { - for _, block := range msg.Content.ContentBlocks { - if block.Text != nil { - systemContent = append(systemContent, *block.Text) - } - } - } + // Handle text content delta + if chunk.Delta.Message != nil && chunk.Delta.Message.Content != nil && chunk.Delta.Message.Content.CohereStreamContentObject != nil && chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil && *chunk.Delta.Message.Content.CohereStreamContentObject.Text != "" { + // Emit output_text.delta (not reasoning_summary_text.delta for regular text) + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Delta: chunk.Delta.Message.Content.CohereStreamContentObject.Text, } - } else { - cohereMsg := CohereMessage{ - Role: role, + if itemID != "" { + response.ItemID = &itemID } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } - // Convert content - only if Content is not nil - if msg.Content != nil { - if msg.Content.ContentStr != nil { - cohereMsg.Content = NewStringContent(*msg.Content.ContentStr) - } else if msg.Content.ContentBlocks != nil { - contentBlocks := convertResponsesMessageContentBlocksToCohere(msg.Content.ContentBlocks) - cohereMsg.Content = NewBlocksContent(contentBlocks) - } + // Handle thinking content delta + if chunk.Delta.Message != nil && chunk.Delta.Message.Content != nil && chunk.Delta.Message.Content.CohereStreamContentObject != nil && chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != nil && *chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != "" { + // Emit reasoning_summary_text.delta for thinking content + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Delta: chunk.Delta.Message.Content.CohereStreamContentObject.Thinking, } - - cohereMessages = append(cohereMessages, cohereMsg) + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false } + } + return nil, nil, false + case StreamEventContentEnd: + // Content block is complete - emit output_text.done, content_part.done, and output_item.done (OpenAI-style) + if chunk.Index != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) + itemID := state.ItemIDs[outputIndex] + var responses []*schemas.BifrostResponsesStreamResponse - case schemas.ResponsesMessageTypeFunctionCall: - // Handle function calls from Responses - assistantMsg := CohereMessage{ - Role: "assistant", - } + // Check if this content index is a reasoning block + if state.ReasoningContentIndices[*chunk.Index] { + // Emit reasoning_summary_text.done (reasoning equivalent of output_text.done) + emptyText := "" + reasoningDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeReasoningSummaryTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Text: &emptyText, + } + if itemID != "" { + reasoningDoneResponse.ItemID = &itemID + } + responses = append(responses, reasoningDoneResponse) - // Extract function call details - var cohereToolCalls []CohereToolCall - toolCall := CohereToolCall{ - Type: "function", - Function: &CohereFunction{}, - } + // Emit content_part.done for reasoning + partDoneResponse := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + } + if itemID != "" { + partDoneResponse.ItemID = &itemID + } + responses = append(responses, partDoneResponse) - if msg.CallID != nil { - toolCall.ID = msg.CallID - } + // Clear the reasoning content index tracking + delete(state.ReasoningContentIndices, *chunk.Index) + } else { + // Regular text block - emit output_text.done (without accumulated text, just the event) + emptyText := "" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + ItemID: &itemID, + Text: &emptyText, + }) - // Get function details from AssistantMessage - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Arguments != nil { - toolCall.Function.Arguments = *msg.ResponsesToolMessage.Arguments + // Emit content_part.done + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + ItemID: &itemID, + }) } - // Get name from ToolMessage if available - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { - toolCall.Function.Name = msg.ResponsesToolMessage.Name + // Emit output_item.done for all content blocks (text, reasoning, etc.) + statusCompleted := "completed" + doneItem := &schemas.ResponsesMessage{ + Status: &statusCompleted, } - - cohereToolCalls = append(cohereToolCalls, toolCall) - - if len(cohereToolCalls) > 0 { - assistantMsg.ToolCalls = cohereToolCalls + if itemID != "" { + doneItem.ID = &itemID } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: doneItem, + }) + return responses, nil, false + } + case StreamEventToolPlanDelta: + if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil && *chunk.Delta.Message.ToolPlan != "" { + // Tool plan delta - treat as normal text (Option A) + // Use output_index 0 for text message if it exists, otherwise create new + outputIndex := 0 + var responses []*schemas.BifrostResponsesStreamResponse - cohereMessages = append(cohereMessages, assistantMsg) + if state.ToolPlanOutputIndex != nil { + outputIndex = *state.ToolPlanOutputIndex + } else { + // Create message item first if it doesn't exist + outputIndex = 0 + state.ToolPlanOutputIndex = &outputIndex + state.ContentIndexToOutputIndex[0] = outputIndex - case "function_call_output": - // Handle function call outputs - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { - toolMsg := CohereMessage{ - Role: "tool", + // Generate stable ID for text item + // Generate stable ID for text item + var itemID string + if state.MessageID == nil { + itemID = fmt.Sprintf("item_%d", outputIndex) + } else { + itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) } + state.ItemIDs[outputIndex] = itemID - // Extract content from ResponsesFunctionToolCallOutput if Content is not set - // This is needed for OpenAI Responses API which uses an "output" field - content := msg.Content - if content == nil && msg.ResponsesToolMessage.Output != nil { - content = &schemas.ResponsesMessageContent{} - if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - content.ContentStr = msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr - } else if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { - content.ContentBlocks = msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks - } - } + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant - // Convert content - only if Content is not nil - if content != nil { - if content.ContentStr != nil { - toolMsg.Content = NewStringContent(*content.ContentStr) - } else if content.ContentBlocks != nil { - contentBlocks := convertResponsesMessageContentBlocksToCohere(content.ContentBlocks) - toolMsg.Content = NewBlocksContent(contentBlocks) - } + item := &schemas.ResponsesMessage{ + ID: &itemID, + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, + }, } - toolMsg.ToolCallID = msg.ResponsesToolMessage.CallID + // Emit output_item.added for text message + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(0), + Item: item, + }) - cohereMessages = append(cohereMessages, toolMsg) + // Emit content_part.added with empty output_text part + emptyText := "" + part := &schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &emptyText, + } + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeContentPartAdded, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(0), + ItemID: &itemID, + Part: part, + }) } - } - } - // Prepend system messages if any - if len(systemContent) > 0 { - systemMsg := CohereMessage{ - Role: "system", - Content: NewStringContent(strings.Join(systemContent, "\n")), + // Emit output_text.delta (not reasoning_summary_text.delta) + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: schemas.Ptr(0), // Tool plan is typically at index 0 + Delta: chunk.Delta.Message.ToolPlan, + } + if itemID != "" { + response.ItemID = &itemID + } + responses = append(responses, response) + return responses, nil, false } - cohereMessages = append([]CohereMessage{systemMsg}, cohereMessages...) - } - - return cohereMessages -} - -// convertBifrostContentBlocksToCohere converts Bifrost content blocks to Cohere format -func convertResponsesMessageContentBlocksToCohere(blocks []schemas.ResponsesMessageContentBlock) []CohereContentBlock { - var cohereBlocks []CohereContentBlock - - for _, block := range blocks { - switch block.Type { - case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: - // Handle both input_text (user messages) and output_text (assistant messages) - if block.Text != nil { - cohereBlocks = append(cohereBlocks, CohereContentBlock{ - Type: CohereContentBlockTypeText, - Text: block.Text, - }) - } - case schemas.ResponsesInputMessageContentBlockTypeImage: - if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil && *block.ResponsesInputMessageContentBlockImage.ImageURL != "" { - cohereBlocks = append(cohereBlocks, CohereContentBlock{ - Type: CohereContentBlockTypeImage, - ImageURL: &CohereImageURL{ - URL: *block.ResponsesInputMessageContentBlockImage.ImageURL, - }, - }) - } - case schemas.ResponsesOutputMessageContentTypeReasoning: - if block.Text != nil { - cohereBlocks = append(cohereBlocks, CohereContentBlock{ - Type: CohereContentBlockTypeThinking, - Thinking: block.Text, - }) - } - } - } - - return cohereBlocks -} - -// convertCohereMessageToResponsesOutput converts Cohere message to Responses output format -func convertCohereMessageToResponsesOutput(cohereMsg CohereMessage) []schemas.ResponsesMessage { - var outputMessages []schemas.ResponsesMessage - - // Handle text content first - if cohereMsg.Content != nil { - var content schemas.ResponsesMessageContent - - var contentBlocks []schemas.ResponsesMessageContentBlock - - if cohereMsg.Content.StringContent != nil { - contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: cohereMsg.Content.StringContent, - }) - } else if cohereMsg.Content.BlocksContent != nil { - // Convert content blocks - for _, block := range cohereMsg.Content.BlocksContent { - contentBlocks = append(contentBlocks, convertCohereContentBlockToBifrost(block)) - } - } - content.ContentBlocks = contentBlocks - - // Create message output - if content.ContentBlocks != nil { - outputMsg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &content, - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - } - - outputMessages = append(outputMessages, outputMsg) - } - } - - // Handle tool calls - if cohereMsg.ToolCalls != nil { - for _, toolCall := range cohereMsg.ToolCalls { - // Check if Function is nil to avoid nil pointer dereference - if toolCall.Function == nil { - // Skip this tool call if Function is nil - continue - } - - // Safely extract function name and arguments - var functionName *string - var functionArguments *string - - if toolCall.Function.Name != nil { - functionName = toolCall.Function.Name - } else { - // Use empty string if Name is nil - functionName = schemas.Ptr("") - } - - // Arguments is a string, not a pointer, so it's safe to access directly - functionArguments = schemas.Ptr(toolCall.Function.Arguments) - - toolCallMsg := schemas.ResponsesMessage{ - ID: toolCall.ID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: schemas.Ptr("completed"), - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - Name: functionName, - CallID: toolCall.ID, - Arguments: functionArguments, - }, - } - - outputMessages = append(outputMessages, toolCallMsg) - } - } - - return outputMessages -} - -// convertCohereContentBlockToBifrost converts CohereContentBlock to schemas.ContentBlock for Responses -func convertCohereContentBlockToBifrost(cohereBlock CohereContentBlock) schemas.ResponsesMessageContentBlock { - switch cohereBlock.Type { - case CohereContentBlockTypeText: - return schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: cohereBlock.Text, - } - case CohereContentBlockTypeImage: - // For images, create a text block describing the image - if cohereBlock.ImageURL == nil { - // Skip invalid image blocks without ImageURL - return schemas.ResponsesMessageContentBlock{} - } - return schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeImage, - ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ - ImageURL: &cohereBlock.ImageURL.URL, - }, - } - case CohereContentBlockTypeThinking: - return schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeReasoning, - Text: cohereBlock.Thinking, - } - default: - // Fallback to text block - return schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: schemas.Ptr(string(cohereBlock.Type)), - } - } -} - -func (chunk *CohereStreamEvent) ToBifrostResponsesStream(sequenceNumber int, state *CohereResponsesStreamState) ([]*schemas.BifrostResponsesStreamResponse, *schemas.BifrostError, bool) { - switch chunk.Type { - case StreamEventMessageStart: - // Message start - emit response.created and response.in_progress (OpenAI-style lifecycle) - if chunk.ID != nil { - state.MessageID = chunk.ID - // Use the state's CreatedAt for consistency - if state.CreatedAt == 0 { - state.CreatedAt = int(time.Now().Unix()) - } - - var responses []*schemas.BifrostResponsesStreamResponse - - // Emit response.created - if !state.HasEmittedCreated { - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeCreated, - SequenceNumber: sequenceNumber, - Response: response, - }) - state.HasEmittedCreated = true - } - - // Emit response.in_progress - if !state.HasEmittedInProgress { - response := &schemas.BifrostResponsesResponse{ - ID: state.MessageID, - CreatedAt: state.CreatedAt, // Use same timestamp - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeInProgress, - SequenceNumber: sequenceNumber + len(responses), - Response: response, - }) - state.HasEmittedInProgress = true - } - - if len(responses) > 0 { - return responses, nil, false - } - } - case StreamEventContentStart: - // Content block start - emit output_item.added (OpenAI-style) - // First, close tool plan message item if it's still open - var responses []*schemas.BifrostResponsesStreamResponse - if state.ToolPlanOutputIndex != nil { - outputIndex := *state.ToolPlanOutputIndex - itemID := state.ItemIDs[outputIndex] + return nil, nil, false + case StreamEventToolCallStart: + // First, close tool plan message item if it's still open + var responses []*schemas.BifrostResponsesStreamResponse + if state.ToolPlanOutputIndex != nil { + outputIndex := *state.ToolPlanOutputIndex + itemID := state.ItemIDs[outputIndex] // Emit output_text.done (without accumulated text, just the event) emptyText := "" @@ -652,107 +616,77 @@ func (chunk *CohereStreamEvent) ToBifrostResponsesStream(sequenceNumber int, sta state.ToolPlanOutputIndex = nil // Mark as closed } - if chunk.Delta != nil && chunk.Index != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Content != nil && chunk.Delta.Message.Content.CohereStreamContentObject != nil { - outputIndex := state.getOrCreateOutputIndex(chunk.Index) - - switch chunk.Delta.Message.Content.CohereStreamContentObject.Type { - case CohereContentBlockTypeText: - // Text block - emit output_item.added with type "message" - messageType := schemas.ResponsesMessageTypeMessage - role := schemas.ResponsesInputMessageRoleAssistant - - // Generate stable ID for text item - var itemID string - if state.MessageID == nil { - itemID = fmt.Sprintf("item_%d", outputIndex) - } else { - itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) + if chunk.Index != nil && chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil { + // Tool call start - emit output_item.added with type "function_call" and status "in_progress" + toolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject + if toolCall.Function != nil && toolCall.Function.Name != nil { + // Always use a new output index for tool calls to avoid collision with text items + // Use output_index 1 (or next available) to avoid collision with text at index 0 + outputIndex := state.CurrentOutputIndex + if outputIndex == 0 { + outputIndex = 1 // Skip 0 if it's used for text } - if state.MessageID == nil { - itemID = fmt.Sprintf("item_%d", outputIndex) + state.CurrentOutputIndex = outputIndex + 1 + // Optionally map the content index if provided + if chunk.Index != nil { + state.ContentIndexToOutputIndex[*chunk.Index] = outputIndex + } + + statusInProgress := "in_progress" + itemID := "" + if toolCall.ID != nil { + itemID = *toolCall.ID + state.ItemIDs[outputIndex] = itemID } - state.ItemIDs[outputIndex] = itemID item := &schemas.ResponsesMessage{ - ID: &itemID, - Type: &messageType, - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // Empty blocks slice for mutation support + ID: toolCall.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: &statusInProgress, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: schemas.Ptr(""), // Arguments will be filled by deltas }, } + // Initialize argument buffer for this tool call + state.ToolArgumentBuffers[outputIndex] = "" + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, SequenceNumber: sequenceNumber + len(responses), OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, Item: item, }) + return responses, nil, false + } + } + if len(responses) > 0 { + return responses, nil, false + } + return nil, nil, false + case StreamEventToolCallDelta: + if chunk.Index != nil && chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil { + // Tool call delta - handle function arguments streaming + toolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject + if toolCall.Function != nil { + outputIndex := state.getOrCreateOutputIndex(chunk.Index) - // Emit content_part.added with empty output_text part - emptyText := "" - part := &schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &emptyText, - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - ItemID: &itemID, - Part: part, - }) - return responses, nil, false - case CohereContentBlockTypeThinking: - // Thinking/reasoning content - emit as reasoning item - messageType := schemas.ResponsesMessageTypeReasoning - role := schemas.ResponsesInputMessageRoleAssistant - - // Generate stable ID for reasoning item - itemID := fmt.Sprintf("msg_%s_reasoning_%d", *state.MessageID, outputIndex) - if state.MessageID == nil { - itemID = fmt.Sprintf("reasoning_%d", outputIndex) - } - state.ItemIDs[outputIndex] = itemID - - item := &schemas.ResponsesMessage{ - ID: &itemID, - Type: &messageType, - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{}, - }, + // Accumulate tool arguments in buffer + if _, exists := state.ToolArgumentBuffers[outputIndex]; !exists { + state.ToolArgumentBuffers[outputIndex] = "" } + state.ToolArgumentBuffers[outputIndex] += toolCall.Function.Arguments - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: item, - }) - return responses, nil, false - } - } - if len(responses) > 0 { - return responses, nil, false - } - case StreamEventContentDelta: - if chunk.Index != nil && chunk.Delta != nil { - outputIndex := state.getOrCreateOutputIndex(chunk.Index) - - // Handle text content delta - if chunk.Delta.Message != nil && chunk.Delta.Message.Content != nil && chunk.Delta.Message.Content.CohereStreamContentObject != nil && chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil && *chunk.Delta.Message.Content.CohereStreamContentObject.Text != "" { - // Emit output_text.delta (not reasoning_summary_text.delta for regular text) + // Emit function_call_arguments.delta itemID := state.ItemIDs[outputIndex] response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta, SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), ContentIndex: chunk.Index, - Delta: chunk.Delta.Message.Content.CohereStreamContentObject.Text, + OutputIndex: schemas.Ptr(outputIndex), + Delta: schemas.Ptr(toolCall.Function.Arguments), } if itemID != "" { response.ItemID = &itemID @@ -761,35 +695,33 @@ func (chunk *CohereStreamEvent) ToBifrostResponsesStream(sequenceNumber int, sta } } return nil, nil, false - case StreamEventContentEnd: - // Content block is complete - emit output_text.done, content_part.done, and output_item.done (OpenAI-style) + case StreamEventToolCallEnd: if chunk.Index != nil { + // Tool call end - emit function_call_arguments.done then output_item.done outputIndex := state.getOrCreateOutputIndex(chunk.Index) - itemID := state.ItemIDs[outputIndex] var responses []*schemas.BifrostResponsesStreamResponse - // Emit output_text.done (without accumulated text, just the event) - emptyText := "" - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - ItemID: &itemID, - Text: &emptyText, - }) - - // Emit content_part.done - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - ItemID: &itemID, - }) + // Emit function_call_arguments.done with full accumulated JSON + if accumulatedArgs, hasArgs := state.ToolArgumentBuffers[outputIndex]; hasArgs && accumulatedArgs != "" { + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Arguments: &accumulatedArgs, + } + if itemID != "" { + response.ItemID = &itemID + } + responses = append(responses, response) + // Clear the buffer + delete(state.ToolArgumentBuffers, outputIndex) + } - // Emit output_item.done + // Emit output_item.done for the function call statusCompleted := "completed" + itemID := state.ItemIDs[outputIndex] doneItem := &schemas.ResponsesMessage{ Status: &statusCompleted, } @@ -803,411 +735,835 @@ func (chunk *CohereStreamEvent) ToBifrostResponsesStream(sequenceNumber int, sta ContentIndex: chunk.Index, Item: doneItem, }) + return responses, nil, false } - case StreamEventToolPlanDelta: - if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil && *chunk.Delta.Message.ToolPlan != "" { - // Tool plan delta - treat as normal text (Option A) - // Use output_index 0 for text message if it exists, otherwise create new - outputIndex := 0 - var responses []*schemas.BifrostResponsesStreamResponse - - if state.ToolPlanOutputIndex != nil { - outputIndex = *state.ToolPlanOutputIndex - } else { - // Create message item first if it doesn't exist - outputIndex = 0 - state.ToolPlanOutputIndex = &outputIndex - state.ContentIndexToOutputIndex[0] = outputIndex + return nil, nil, false + case StreamEventCitationStart: + if chunk.Index != nil && chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Citations != nil { + // Citation start - create annotation for the citation + citation := chunk.Delta.Message.Citations.CohereStreamCitationObject - // Generate stable ID for text item - // Generate stable ID for text item - var itemID string - if state.MessageID == nil { - itemID = fmt.Sprintf("item_%d", outputIndex) - } else { - itemID = fmt.Sprintf("msg_%s_item_%d", *state.MessageID, outputIndex) - } - state.ItemIDs[outputIndex] = itemID + // Map Cohere citation to ResponsesOutputMessageContentTextAnnotation + annotation := &schemas.ResponsesOutputMessageContentTextAnnotation{ + Type: "file_citation", // Default to file_citation + StartIndex: schemas.Ptr(citation.Start), + EndIndex: schemas.Ptr(citation.End), + } - messageType := schemas.ResponsesMessageTypeMessage - role := schemas.ResponsesInputMessageRoleAssistant + // Set annotation type and metadata + if len(citation.Sources) > 0 { + source := citation.Sources[0] - item := &schemas.ResponsesMessage{ - ID: &itemID, - Type: &messageType, - Role: &role, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{}, - }, + if source.ID != nil { + annotation.FileID = source.ID } - // Emit output_item.added for text message - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - Item: item, - }) - - // Emit content_part.added with empty output_text part - emptyText := "" - part := &schemas.ResponsesMessageContentBlock{ - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: &emptyText, + if source.Document != nil { + if title, ok := (*source.Document)["title"].(string); ok { + annotation.Title = &title + } + if id, ok := (*source.Document)["id"].(string); ok && annotation.FileID == nil { + annotation.FileID = &id + } + if snippet, ok := (*source.Document)["snippet"].(string); ok { + annotation.Text = &snippet + } + if url, ok := (*source.Document)["url"].(string); ok { + annotation.URL = &url + } } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - Part: part, - }) } - // Emit output_text.delta (not reasoning_summary_text.delta) - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), // Tool plan is typically at index 0 - Delta: chunk.Delta.Message.ToolPlan, + // Use output_index based on content index for citations (they're part of the text item) + outputIndex := 0 + if citation.ContentIndex >= 0 { + contentIndexPtr := &citation.ContentIndex + outputIndex = state.getOrCreateOutputIndex(contentIndexPtr) } - if itemID != "" { - response.ItemID = &itemID + + // Record mapping from annotation index to content index for citation pairing + if chunk.Index != nil && citation.ContentIndex >= 0 { + state.AnnotationIndexToContentIndex[*chunk.Index] = citation.ContentIndex } - responses = append(responses, response) - return responses, nil, false + + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeOutputTextAnnotationAdded, + SequenceNumber: sequenceNumber, + ContentIndex: schemas.Ptr(citation.ContentIndex), + Annotation: annotation, + OutputIndex: schemas.Ptr(outputIndex), + AnnotationIndex: chunk.Index, + }}, nil, false } return nil, nil, false - case StreamEventToolCallStart: - // First, close tool plan message item if it's still open - var responses []*schemas.BifrostResponsesStreamResponse - if state.ToolPlanOutputIndex != nil { - outputIndex := *state.ToolPlanOutputIndex - itemID := state.ItemIDs[outputIndex] - - // Emit output_text.done (without accumulated text, just the event) - emptyText := "" - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputTextDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - Text: &emptyText, - }) + case StreamEventCitationEnd: + if chunk.Index != nil { + // Citation end - indicate annotation is complete + // Look up the original content index from state using the annotation index + contentIndex, exists := state.AnnotationIndexToContentIndex[*chunk.Index] + if !exists { + // Fallback: if mapping not found, use annotation index (shouldn't happen in normal flow) + contentIndex = *chunk.Index + } - // Emit content_part.done - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeContentPartDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - ItemID: &itemID, - }) + // Derive outputIndex from the content index + contentIndexPtr := &contentIndex + outputIndex := state.getOrCreateOutputIndex(contentIndexPtr) - // Emit output_item.done - statusCompleted := "completed" - doneItem := &schemas.ResponsesMessage{ - Status: &statusCompleted, - } - if itemID != "" { - doneItem.ID = &itemID - } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: schemas.Ptr(0), - Item: doneItem, - }) - state.ToolPlanOutputIndex = nil // Mark as closed + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeOutputTextAnnotationDone, + SequenceNumber: sequenceNumber, + ContentIndex: &contentIndex, + OutputIndex: schemas.Ptr(outputIndex), + AnnotationIndex: chunk.Index, + }}, nil, false + } + return nil, nil, false + case StreamEventMessageEnd: + // Message end - emit response.completed (OpenAI-style) + response := &schemas.BifrostResponsesResponse{ + CreatedAt: state.CreatedAt, + } + if state.MessageID != nil { + response.ID = state.MessageID + } + if state.Model != nil { + response.Model = *state.Model } - if chunk.Index != nil && chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil { - // Tool call start - emit output_item.added with type "function_call" and status "in_progress" - toolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject - if toolCall.Function != nil && toolCall.Function.Name != nil { - // Always use a new output index for tool calls to avoid collision with text items - // Use output_index 1 (or next available) to avoid collision with text at index 0 - outputIndex := state.CurrentOutputIndex - if outputIndex == 0 { - outputIndex = 1 // Skip 0 if it's used for text - } - state.CurrentOutputIndex = outputIndex + 1 - // Optionally map the content index if provided - if chunk.Index != nil { - state.ContentIndexToOutputIndex[*chunk.Index] = outputIndex - } + if chunk.Delta != nil { + if chunk.Delta.Usage != nil { + usage := &schemas.ResponsesResponseUsage{} - statusInProgress := "in_progress" - itemID := "" - if toolCall.ID != nil { - itemID = *toolCall.ID - state.ItemIDs[outputIndex] = itemID + if chunk.Delta.Usage.Tokens != nil { + if chunk.Delta.Usage.Tokens.InputTokens != nil { + usage.InputTokens = *chunk.Delta.Usage.Tokens.InputTokens + } + if chunk.Delta.Usage.Tokens.OutputTokens != nil { + usage.OutputTokens = *chunk.Delta.Usage.Tokens.OutputTokens + } + usage.TotalTokens = usage.InputTokens + usage.OutputTokens } - item := &schemas.ResponsesMessage{ - ID: toolCall.ID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - Status: &statusInProgress, - ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: toolCall.ID, - Name: toolCall.Function.Name, - Arguments: schemas.Ptr(""), // Arguments will be filled by deltas - }, + if chunk.Delta.Usage.CachedTokens != nil { + usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + CachedTokens: *chunk.Delta.Usage.CachedTokens, + } } + response.Usage = usage + } + } - // Initialize argument buffer for this tool call - state.ToolArgumentBuffers[outputIndex] = "" + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + SequenceNumber: sequenceNumber, + Response: response, + }}, nil, true + case StreamEventDebug: + return nil, nil, false + } + return nil, nil, false +} - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - Item: item, - }) - return responses, nil, false +// ConvertResponsesTextFormatToCohere converts Bifrost Responses Text.Format to Cohere's typed format +// Responses format: Text.Format with type "json_schema", "json_object", or "text" +// Cohere format: { type: "json_object", json_schema: {...} } +func convertResponsesTextFormatToCohere(textFormat *schemas.ResponsesTextConfigFormat) *CohereResponseFormat { + if textFormat == nil { + return nil + } + + cohereFormat := &CohereResponseFormat{} + + // Convert type + switch textFormat.Type { + case "text": + cohereFormat.Type = ResponseFormatTypeText + case "json_object": + cohereFormat.Type = ResponseFormatTypeJSONObject + case "json_schema": + cohereFormat.Type = ResponseFormatTypeJSONObject + + // If schema is provided, extract it + if textFormat.JSONSchema != nil { + // Build schema map + schema := make(map[string]interface{}) + if textFormat.JSONSchema.Type != nil { + schema["type"] = *textFormat.JSONSchema.Type } + if textFormat.JSONSchema.Properties != nil { + schema["properties"] = *textFormat.JSONSchema.Properties + } + if len(textFormat.JSONSchema.Required) > 0 { + schema["required"] = textFormat.JSONSchema.Required + } + if textFormat.JSONSchema.AdditionalProperties != nil { + schema["additionalProperties"] = *textFormat.JSONSchema.AdditionalProperties + } + + var schemaInterface interface{} = schema + cohereFormat.JSONSchema = &schemaInterface } - if len(responses) > 0 { - return responses, nil, false + default: + cohereFormat.Type = ResponseFormatTypeJSONObject + } + + return cohereFormat +} + +// ToCohereResponsesRequest converts a BifrostRequest (Responses structure) to CohereChatRequest +func ToCohereResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*CohereChatRequest, error) { + if bifrostReq == nil { + return nil, nil + } + + cohereReq := &CohereChatRequest{ + Model: bifrostReq.Model, + } + + // Map basic parameters + if bifrostReq.Params != nil { + if bifrostReq.Params.MaxOutputTokens != nil { + cohereReq.MaxTokens = bifrostReq.Params.MaxOutputTokens + } + if bifrostReq.Params.Temperature != nil { + cohereReq.Temperature = bifrostReq.Params.Temperature + } + if bifrostReq.Params.TopP != nil { + cohereReq.P = bifrostReq.Params.TopP } - return nil, nil, false - case StreamEventToolCallDelta: - if chunk.Index != nil && chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil { - // Tool call delta - handle function arguments streaming - toolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject - if toolCall.Function != nil { - outputIndex := state.getOrCreateOutputIndex(chunk.Index) - // Accumulate tool arguments in buffer - if _, exists := state.ToolArgumentBuffers[outputIndex]; !exists { - state.ToolArgumentBuffers[outputIndex] = "" + // Convert reasoning + if bifrostReq.Params.Reasoning != nil { + if bifrostReq.Params.Reasoning.MaxTokens != nil { + cohereReq.Thinking = &CohereThinking{ + Type: ThinkingTypeEnabled, + TokenBudget: bifrostReq.Params.Reasoning.MaxTokens, } - state.ToolArgumentBuffers[outputIndex] += toolCall.Function.Arguments + } else { + if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort != "none" { + maxOutputTokens := DefaultCompletionMaxTokens + if bifrostReq.Params.MaxOutputTokens != nil { + maxOutputTokens = *bifrostReq.Params.MaxOutputTokens + } + budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxOutputTokens) + if err != nil { + return nil, err + } + cohereReq.Thinking = &CohereThinking{ + Type: ThinkingTypeEnabled, + TokenBudget: schemas.Ptr(budgetTokens), + } + } else { + cohereReq.Thinking = &CohereThinking{ + Type: ThinkingTypeDisabled, + } + } + } + } - // Emit function_call_arguments.delta - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta, - SequenceNumber: sequenceNumber, - ContentIndex: chunk.Index, - OutputIndex: schemas.Ptr(outputIndex), - Delta: schemas.Ptr(toolCall.Function.Arguments), + if bifrostReq.Params.Text != nil && bifrostReq.Params.Text.Format != nil { + cohereReq.ResponseFormat = convertResponsesTextFormatToCohere(bifrostReq.Params.Text.Format) + } + if bifrostReq.Params.ExtraParams != nil { + if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { + cohereReq.K = topK + } + if stop, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["stop"]); ok { + cohereReq.StopSequences = stop + } + if frequencyPenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["frequency_penalty"]); ok { + cohereReq.FrequencyPenalty = frequencyPenalty + } + if presencePenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["presence_penalty"]); ok { + cohereReq.PresencePenalty = presencePenalty + } + if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok { + if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok { + thinking := &CohereThinking{} + if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok { + thinking.Type = CohereThinkingType(typeStr) + } + if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok { + thinking.TokenBudget = tokenBudget + } + cohereReq.Thinking = thinking } - if itemID != "" { - response.ItemID = &itemID + } + } + } + + // Convert tools + if bifrostReq.Params != nil && bifrostReq.Params.Tools != nil { + var cohereTools []CohereChatRequestTool + for _, tool := range bifrostReq.Params.Tools { + if tool.ResponsesToolFunction != nil && tool.Name != nil { + cohereTool := CohereChatRequestTool{ + Type: "function", + Function: CohereChatRequestFunction{ + Name: *tool.Name, + Description: tool.Description, + Parameters: tool.ResponsesToolFunction.Parameters, + }, } - return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + cohereTools = append(cohereTools, cohereTool) } } - return nil, nil, false - case StreamEventToolCallEnd: - if chunk.Index != nil { - // Tool call end - emit function_call_arguments.done then output_item.done - outputIndex := state.getOrCreateOutputIndex(chunk.Index) - var responses []*schemas.BifrostResponsesStreamResponse - // Emit function_call_arguments.done with full accumulated JSON - if accumulatedArgs, hasArgs := state.ToolArgumentBuffers[outputIndex]; hasArgs && accumulatedArgs != "" { - itemID := state.ItemIDs[outputIndex] - response := &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone, - SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Arguments: &accumulatedArgs, + if len(cohereTools) > 0 { + cohereReq.Tools = cohereTools + } + } + + // Convert tool choice + if bifrostReq.Params != nil && bifrostReq.Params.ToolChoice != nil { + cohereReq.ToolChoice = convertBifrostToolChoiceToCohereToolChoice(*bifrostReq.Params.ToolChoice) + } + + // Process ResponsesInput (which contains the Responses items) + if bifrostReq.Input != nil { + cohereReq.Messages = ConvertBifrostMessagesToCohereMessages(bifrostReq.Input) + } + + return cohereReq, nil +} + +// ToBifrostResponsesResponse converts CohereChatResponse to BifrostResponse (Responses structure) +func (response *CohereChatResponse) ToBifrostResponsesResponse() *schemas.BifrostResponsesResponse { + if response == nil { + return nil + } + + bifrostResp := &schemas.BifrostResponsesResponse{ + ID: schemas.Ptr(response.ID), + CreatedAt: int(time.Now().Unix()), // Set current timestamp + } + + // Convert usage information + if response.Usage != nil { + usage := &schemas.ResponsesResponseUsage{} + + if response.Usage.Tokens != nil { + if response.Usage.Tokens.InputTokens != nil { + usage.InputTokens = *response.Usage.Tokens.InputTokens + } + if response.Usage.Tokens.OutputTokens != nil { + usage.OutputTokens = *response.Usage.Tokens.OutputTokens + } + usage.TotalTokens = usage.InputTokens + usage.OutputTokens + } + + if response.Usage.CachedTokens != nil { + usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + CachedTokens: *response.Usage.CachedTokens, + } + } + + bifrostResp.Usage = usage + } + + // Convert output message to Responses format + if response.Message != nil { + outputMessages := ConvertCohereMessagesToBifrostMessages([]CohereMessage{*response.Message}, true) + bifrostResp.Output = outputMessages + } + + return bifrostResp +} + +// ConvertBifrostMessagesToCohereMessages converts an array of Bifrost ResponsesMessage to Cohere message format +// This is the main conversion method from Bifrost to Cohere - handles all message types and returns messages +func ConvertBifrostMessagesToCohereMessages(bifrostMessages []schemas.ResponsesMessage) []CohereMessage { + var cohereMessages []CohereMessage + var systemContent []string + var pendingReasoningContentBlocks []CohereContentBlock + var currentAssistantMessage *CohereMessage + + for _, msg := range bifrostMessages { + // Handle nil Type with default + msgType := schemas.ResponsesMessageTypeMessage + if msg.Type != nil { + msgType = *msg.Type + } + + switch msgType { + case schemas.ResponsesMessageTypeMessage: + // Handle nil Role with default + role := "user" + if msg.Role != nil { + role = string(*msg.Role) + } + + if role == "system" { + // Collect system messages separately for Cohere + systemMsgs := convertBifrostMessageToCohereSystemContent(&msg) + systemContent = append(systemContent, systemMsgs...) + } else { + // Convert regular message + cohereMsg := convertBifrostMessageToCohereMessage(&msg) + if cohereMsg != nil { + if role == "assistant" { + // Add any pending reasoning content blocks to the message + if len(pendingReasoningContentBlocks) > 0 { + // copy the pending reasoning content blocks + copied := make([]CohereContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + contentBlocks := copied + pendingReasoningContentBlocks = nil + // Add content blocks after pending reasoning content blocks are added + if msg.Content != nil { + if msg.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeText, + Text: msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + contentBlocks = append(contentBlocks, convertResponsesMessageContentBlocksToCohere(msg.Content.ContentBlocks)...) + } + } + cohereMsg.Content = NewBlocksContent(contentBlocks) + } + // Store assistant message for potential reasoning blocks + currentAssistantMessage = cohereMsg + } else { + // Flush any pending assistant message first for non-assistant messages + if currentAssistantMessage != nil { + if len(pendingReasoningContentBlocks) > 0 { + if currentAssistantMessage.Content == nil { + currentAssistantMessage.Content = NewBlocksContent(pendingReasoningContentBlocks) + } else if currentAssistantMessage.Content.BlocksContent != nil { + currentAssistantMessage.Content.BlocksContent = append(currentAssistantMessage.Content.BlocksContent, pendingReasoningContentBlocks...) + } + pendingReasoningContentBlocks = nil + } + cohereMessages = append(cohereMessages, *currentAssistantMessage) + currentAssistantMessage = nil + } + cohereMessages = append(cohereMessages, *cohereMsg) + } } - if itemID != "" { - response.ItemID = &itemID + } + + case schemas.ResponsesMessageTypeReasoning: + // Handle reasoning as thinking content blocks + reasoningBlocks := convertBifrostReasoningToCohereThinking(&msg) + if len(reasoningBlocks) > 0 { + if currentAssistantMessage == nil { + currentAssistantMessage = &CohereMessage{ + Role: "assistant", + } } - responses = append(responses, response) - // Clear the buffer - delete(state.ToolArgumentBuffers, outputIndex) + pendingReasoningContentBlocks = append(pendingReasoningContentBlocks, reasoningBlocks...) } - // Emit output_item.done for the function call - statusCompleted := "completed" - itemID := state.ItemIDs[outputIndex] - doneItem := &schemas.ResponsesMessage{ - Status: &statusCompleted, + case schemas.ResponsesMessageTypeFunctionCall: + // Flush any pending reasoning blocks first + if len(pendingReasoningContentBlocks) > 0 && currentAssistantMessage != nil { + if currentAssistantMessage.Content == nil { + currentAssistantMessage.Content = NewBlocksContent(pendingReasoningContentBlocks) + } else if currentAssistantMessage.Content.BlocksContent != nil { + currentAssistantMessage.Content.BlocksContent = append(currentAssistantMessage.Content.BlocksContent, pendingReasoningContentBlocks...) + } + cohereMessages = append(cohereMessages, *currentAssistantMessage) + pendingReasoningContentBlocks = nil + currentAssistantMessage = nil } - if itemID != "" { - doneItem.ID = &itemID + + // Handle function calls from Responses + assistantMsg := convertBifrostFunctionCallToCohereMessage(&msg) + if assistantMsg != nil { + cohereMessages = append(cohereMessages, *assistantMsg) } - responses = append(responses, &schemas.BifrostResponsesStreamResponse{ - Type: schemas.ResponsesStreamResponseTypeOutputItemDone, - SequenceNumber: sequenceNumber + len(responses), - OutputIndex: schemas.Ptr(outputIndex), - ContentIndex: chunk.Index, - Item: doneItem, - }) - return responses, nil, false + case schemas.ResponsesMessageTypeFunctionCallOutput: + // Handle function call outputs + toolMsg := convertBifrostFunctionCallOutputToCohereMessage(&msg) + if toolMsg != nil { + cohereMessages = append(cohereMessages, *toolMsg) + } } - return nil, nil, false - case StreamEventCitationStart: - if chunk.Index != nil && chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Citations != nil { - // Citation start - create annotation for the citation - citation := chunk.Delta.Message.Citations.CohereStreamCitationObject + } - // Map Cohere citation to ResponsesOutputMessageContentTextAnnotation - annotation := &schemas.ResponsesOutputMessageContentTextAnnotation{ - Type: "file_citation", // Default to file_citation - StartIndex: schemas.Ptr(citation.Start), - EndIndex: schemas.Ptr(citation.End), - } + // Flush any remaining pending reasoning blocks + if len(pendingReasoningContentBlocks) > 0 && currentAssistantMessage != nil { + if currentAssistantMessage.Content == nil { + currentAssistantMessage.Content = NewBlocksContent(pendingReasoningContentBlocks) + } else if currentAssistantMessage.Content.BlocksContent != nil { + currentAssistantMessage.Content.BlocksContent = append(currentAssistantMessage.Content.BlocksContent, pendingReasoningContentBlocks...) + } + cohereMessages = append(cohereMessages, *currentAssistantMessage) + } else if currentAssistantMessage != nil { + cohereMessages = append(cohereMessages, *currentAssistantMessage) + } - // Set annotation type and metadata - if len(citation.Sources) > 0 { - source := citation.Sources[0] + // Prepend system messages if any + if len(systemContent) > 0 { + systemMsg := CohereMessage{ + Role: "system", + Content: NewStringContent(strings.Join(systemContent, "\n")), + } + cohereMessages = append([]CohereMessage{systemMsg}, cohereMessages...) + } + + return cohereMessages +} + +// ConvertCohereMessagesToBifrostMessages converts an array of Cohere messages to Bifrost ResponsesMessage format +// This is the main conversion method from Cohere to Bifrost - handles all message types and content blocks +func ConvertCohereMessagesToBifrostMessages(cohereMessages []CohereMessage, isOutputMessage bool) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage + + for _, msg := range cohereMessages { + convertedMessages := convertSingleCohereMessageToBifrostMessages(&msg, isOutputMessage) + bifrostMessages = append(bifrostMessages, convertedMessages...) + } + + return bifrostMessages +} + +// convertBifrostToolChoiceToCohere converts schemas.ToolChoice to CohereToolChoice +func convertBifrostToolChoiceToCohereToolChoice(toolChoice schemas.ResponsesToolChoice) *CohereToolChoice { + toolChoiceString := toolChoice.ResponsesToolChoiceStr + + if toolChoiceString != nil { + switch *toolChoiceString { + case "none": + choice := ToolChoiceNone + return &choice + case "required", "function": + choice := ToolChoiceRequired + return &choice + case "auto": + choice := ToolChoiceAuto + return &choice + default: + choice := ToolChoiceRequired + return &choice + } + } + + return nil +} + +// Helper functions for converting individual Cohere message types - if source.ID != nil { - annotation.FileID = source.ID - } +// convertBifrostMessageToCohereSystemContent converts a Bifrost system message to Cohere system content +func convertBifrostMessageToCohereSystemContent(msg *schemas.ResponsesMessage) []string { + var systemContent []string - if source.Document != nil { - if title, ok := (*source.Document)["title"].(string); ok { - annotation.Title = &title - } - if id, ok := (*source.Document)["id"].(string); ok && annotation.FileID == nil { - annotation.FileID = &id - } - if snippet, ok := (*source.Document)["snippet"].(string); ok { - annotation.Text = &snippet - } - if url, ok := (*source.Document)["url"].(string); ok { - annotation.URL = &url - } + if msg.Content != nil { + if msg.Content.ContentStr != nil { + systemContent = append(systemContent, *msg.Content.ContentStr) + } else if msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + systemContent = append(systemContent, *block.Text) } } + } + } - // Use output_index based on content index for citations (they're part of the text item) - outputIndex := 0 - if citation.ContentIndex >= 0 { - contentIndexPtr := &citation.ContentIndex - outputIndex = state.getOrCreateOutputIndex(contentIndexPtr) - } + return systemContent +} - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeOutputTextAnnotationAdded, - SequenceNumber: sequenceNumber, - ContentIndex: schemas.Ptr(citation.ContentIndex), - Annotation: annotation, - OutputIndex: schemas.Ptr(outputIndex), - AnnotationIndex: chunk.Index, - }}, nil, false +// convertBifrostMessageToCohereMessage converts a regular Bifrost message to Cohere message +func convertBifrostMessageToCohereMessage(msg *schemas.ResponsesMessage) *CohereMessage { + role := "user" + if msg.Role != nil { + role = string(*msg.Role) + } + + cohereMsg := CohereMessage{ + Role: role, + } + + // Convert content - only if Content is not nil + if msg.Content != nil { + if msg.Content.ContentStr != nil { + cohereMsg.Content = NewStringContent(*msg.Content.ContentStr) + } else if msg.Content.ContentBlocks != nil { + contentBlocks := convertResponsesMessageContentBlocksToCohere(msg.Content.ContentBlocks) + cohereMsg.Content = NewBlocksContent(contentBlocks) } - return nil, nil, false - case StreamEventCitationEnd: - if chunk.Index != nil { - // Citation end - indicate annotation is complete - outputIndex := 0 - if chunk.Index != nil { - outputIndex = state.getOrCreateOutputIndex(chunk.Index) + } + + return &cohereMsg +} + +// convertBifrostReasoningToCohereThinking converts a Bifrost reasoning message to Cohere thinking blocks +func convertBifrostReasoningToCohereThinking(msg *schemas.ResponsesMessage) []CohereContentBlock { + var thinkingBlocks []CohereContentBlock + + if msg.Content != nil && msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Type == schemas.ResponsesOutputMessageContentTypeReasoning && block.Text != nil { + thinkingBlock := CohereContentBlock{ + Type: CohereContentBlockTypeThinking, + Thinking: block.Text, + } + thinkingBlocks = append(thinkingBlocks, thinkingBlock) } - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeOutputTextAnnotationDone, - SequenceNumber: sequenceNumber, - ContentIndex: chunk.Index, - OutputIndex: schemas.Ptr(outputIndex), - AnnotationIndex: chunk.Index, - }}, nil, false } - return nil, nil, false - case StreamEventMessageEnd: - // Message end - emit response.completed (OpenAI-style) - response := &schemas.BifrostResponsesResponse{ - CreatedAt: state.CreatedAt, + } else if msg.ResponsesReasoning != nil { + if msg.ResponsesReasoning.Summary != nil { + for _, reasoningContent := range msg.ResponsesReasoning.Summary { + thinkingBlock := CohereContentBlock{ + Type: CohereContentBlockTypeThinking, + Thinking: &reasoningContent.Text, + } + thinkingBlocks = append(thinkingBlocks, thinkingBlock) + } + } else if msg.ResponsesReasoning.EncryptedContent != nil { + // Cohere doesn't have a direct equivalent to encrypted content, + // so we'll store it as a regular thinking block with a special marker + encryptedText := fmt.Sprintf("[ENCRYPTED_REASONING: %s]", *msg.ResponsesReasoning.EncryptedContent) + thinkingBlock := CohereContentBlock{ + Type: CohereContentBlockTypeThinking, + Thinking: &encryptedText, + } + thinkingBlocks = append(thinkingBlocks, thinkingBlock) } - if state.MessageID != nil { - response.ID = state.MessageID + } + + return thinkingBlocks +} + +// convertBifrostFunctionCallToCohereMessage converts a Bifrost function call to Cohere message +func convertBifrostFunctionCallToCohereMessage(msg *schemas.ResponsesMessage) *CohereMessage { + assistantMsg := CohereMessage{ + Role: "assistant", + } + + // Extract function call details + var cohereToolCalls []CohereToolCall + toolCall := CohereToolCall{ + Type: "function", + Function: &CohereFunction{}, + } + + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolCall.ID = msg.CallID + } + + // Get function details from AssistantMessage + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Arguments != nil { + toolCall.Function.Arguments = *msg.ResponsesToolMessage.Arguments + } + + // Get name from ToolMessage if available + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { + toolCall.Function.Name = msg.ResponsesToolMessage.Name + } + + cohereToolCalls = append(cohereToolCalls, toolCall) + + if len(cohereToolCalls) > 0 { + assistantMsg.ToolCalls = cohereToolCalls + } + + return &assistantMsg +} + +// convertBifrostFunctionCallOutputToCohereMessage converts a Bifrost function call output to Cohere message +func convertBifrostFunctionCallOutputToCohereMessage(msg *schemas.ResponsesMessage) *CohereMessage { + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolMsg := CohereMessage{ + Role: "tool", } - if state.Model != nil { - response.Model = *state.Model + + // Extract content from ResponsesFunctionToolCallOutput if Content is not set + // This is needed for OpenAI Responses API which uses an "output" field + content := msg.Content + if content == nil && msg.ResponsesToolMessage.Output != nil { + content = &schemas.ResponsesMessageContent{} + if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { + content.ContentStr = msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + } else if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + content.ContentBlocks = msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks + } } - if chunk.Delta != nil { - if chunk.Delta.Usage != nil { - usage := &schemas.ResponsesResponseUsage{} + // Convert content - only if Content is not nil + if content != nil { + if content.ContentStr != nil { + toolMsg.Content = NewStringContent(*content.ContentStr) + } else if content.ContentBlocks != nil { + contentBlocks := convertResponsesMessageContentBlocksToCohere(content.ContentBlocks) + toolMsg.Content = NewBlocksContent(contentBlocks) + } + } - if chunk.Delta.Usage.Tokens != nil { - if chunk.Delta.Usage.Tokens.InputTokens != nil { - usage.InputTokens = *chunk.Delta.Usage.Tokens.InputTokens - } - if chunk.Delta.Usage.Tokens.OutputTokens != nil { - usage.OutputTokens = *chunk.Delta.Usage.Tokens.OutputTokens - } - usage.TotalTokens = usage.InputTokens + usage.OutputTokens - } + toolMsg.ToolCallID = msg.ResponsesToolMessage.CallID - if chunk.Delta.Usage.CachedTokens != nil { - usage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ - CachedTokens: *chunk.Delta.Usage.CachedTokens, + return &toolMsg + } + return nil +} + +// convertSingleCohereMessageToBifrostMessages converts a single Cohere message to Bifrost messages +func convertSingleCohereMessageToBifrostMessages(cohereMsg *CohereMessage, isOutputMessage bool) []schemas.ResponsesMessage { + var outputMessages []schemas.ResponsesMessage + var reasoningContentBlocks []schemas.ResponsesMessageContentBlock + + // Handle text content first + if cohereMsg.Content != nil { + var content schemas.ResponsesMessageContent + var contentBlocks []schemas.ResponsesMessageContentBlock + + if cohereMsg.Content.StringContent != nil { + // Determine content block type based on message role and output flag + blockType := schemas.ResponsesInputMessageContentBlockTypeText + if isOutputMessage || cohereMsg.Role == "assistant" { + blockType = schemas.ResponsesOutputMessageContentTypeText + } + + contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{ + Type: blockType, + Text: cohereMsg.Content.StringContent, + }) + } else if cohereMsg.Content.BlocksContent != nil { + // Convert content blocks and separate reasoning blocks + for _, block := range cohereMsg.Content.BlocksContent { + if block.Type == CohereContentBlockTypeThinking { + // Collect reasoning blocks to create a single reasoning message + reasoningContentBlocks = append(reasoningContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: block.Thinking, + }) + } else { + converted := convertCohereContentBlockToBifrost(block) + if converted.Type != "" { + contentBlocks = append(contentBlocks, converted) } } - response.Usage = usage } } - return []*schemas.BifrostResponsesStreamResponse{{ - Type: schemas.ResponsesStreamResponseTypeCompleted, - SequenceNumber: sequenceNumber, - Response: response, - }}, nil, true - case StreamEventDebug: - return nil, nil, false + content.ContentBlocks = contentBlocks + + // Create message output if we have content blocks + if len(contentBlocks) > 0 { + var role schemas.ResponsesMessageRoleType + switch cohereMsg.Role { + case "user": + role = schemas.ResponsesInputMessageRoleUser + case "assistant": + role = schemas.ResponsesInputMessageRoleAssistant + case "system": + role = schemas.ResponsesInputMessageRoleSystem + default: + role = schemas.ResponsesInputMessageRoleUser + } + + outputMsg := schemas.ResponsesMessage{ + Role: &role, + Content: &content, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + } + + if isOutputMessage { + outputMsg.ID = schemas.Ptr("msg_" + fmt.Sprintf("%d", time.Now().UnixNano())) + } + + outputMessages = append(outputMessages, outputMsg) + } } - return nil, nil, false -} -// ConvertResponsesTextFormatToCohere converts Bifrost Responses Text.Format to Cohere's typed format -// Responses format: Text.Format with type "json_schema", "json_object", or "text" -// Cohere format: { type: "json_object", json_schema: {...} } -func convertResponsesTextFormatToCohere(textFormat *schemas.ResponsesTextConfigFormat) *CohereResponseFormat { - if textFormat == nil { - return nil + // Handle reasoning blocks - prepend reasoning message if we collected any + if len(reasoningContentBlocks) > 0 { + reasoningMessage := schemas.ResponsesMessage{ + ID: schemas.Ptr("rs_" + fmt.Sprintf("%d", time.Now().UnixNano())), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: reasoningContentBlocks, + }, + } + // Prepend the reasoning message to the start of the messages list + outputMessages = append([]schemas.ResponsesMessage{reasoningMessage}, outputMessages...) } - cohereFormat := &CohereResponseFormat{} + // Handle tool calls + if cohereMsg.ToolCalls != nil { + for _, toolCall := range cohereMsg.ToolCalls { + // Check if Function is nil to avoid nil pointer dereference + if toolCall.Function == nil { + // Skip this tool call if Function is nil + continue + } - // Convert type - switch textFormat.Type { - case "text": - cohereFormat.Type = ResponseFormatTypeText - case "json_object": - cohereFormat.Type = ResponseFormatTypeJSONObject - case "json_schema": - cohereFormat.Type = ResponseFormatTypeJSONObject + // Safely extract function name and arguments + var functionName *string + var functionArguments *string - // If schema is provided, extract it - if textFormat.JSONSchema != nil { - // Build schema map - schema := make(map[string]interface{}) - if textFormat.JSONSchema.Type != nil { - schema["type"] = *textFormat.JSONSchema.Type - } - if textFormat.JSONSchema.Properties != nil { - schema["properties"] = *textFormat.JSONSchema.Properties + if toolCall.Function.Name != nil { + functionName = toolCall.Function.Name + } else { + // Use empty string if Name is nil + functionName = schemas.Ptr("") } - if len(textFormat.JSONSchema.Required) > 0 { - schema["required"] = textFormat.JSONSchema.Required + + // Arguments is a string, not a pointer, so it's safe to access directly + functionArguments = schemas.Ptr(toolCall.Function.Arguments) + + toolCallMsg := schemas.ResponsesMessage{ + ID: toolCall.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Name: functionName, + CallID: toolCall.ID, + Arguments: functionArguments, + }, } - if textFormat.JSONSchema.AdditionalProperties != nil { - schema["additionalProperties"] = *textFormat.JSONSchema.AdditionalProperties + + if isOutputMessage { + role := schemas.ResponsesInputMessageRoleAssistant + toolCallMsg.Role = &role } - var schemaInterface interface{} = schema - cohereFormat.JSONSchema = &schemaInterface + outputMessages = append(outputMessages, toolCallMsg) } - default: - cohereFormat.Type = ResponseFormatTypeJSONObject } - return cohereFormat + return outputMessages +} + +// convertBifrostContentBlocksToCohere converts Bifrost content blocks to Cohere format +func convertResponsesMessageContentBlocksToCohere(blocks []schemas.ResponsesMessageContentBlock) []CohereContentBlock { + var cohereBlocks []CohereContentBlock + + for _, block := range blocks { + switch block.Type { + case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: + // Handle both input_text (user messages) and output_text (assistant messages) + if block.Text != nil { + cohereBlocks = append(cohereBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeText, + Text: block.Text, + }) + } + case schemas.ResponsesInputMessageContentBlockTypeImage: + if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil && *block.ResponsesInputMessageContentBlockImage.ImageURL != "" { + cohereBlocks = append(cohereBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeImage, + ImageURL: &CohereImageURL{ + URL: *block.ResponsesInputMessageContentBlockImage.ImageURL, + }, + }) + } + case schemas.ResponsesOutputMessageContentTypeReasoning: + if block.Text != nil { + cohereBlocks = append(cohereBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeThinking, + Thinking: block.Text, + }) + } + } + } + + return cohereBlocks } diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go index bb20a23d2..45489fefd 100644 --- a/core/providers/cohere/types.go +++ b/core/providers/cohere/types.go @@ -8,6 +8,9 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +const MinimumReasoningMaxTokens = 1 +const DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default + // ==================== REQUEST TYPES ==================== // CohereChatRequest represents a Cohere chat completion request diff --git a/core/providers/gemini/responses.go b/core/providers/gemini/responses.go index 3982d4022..8c2942c11 100644 --- a/core/providers/gemini/responses.go +++ b/core/providers/gemini/responses.go @@ -135,47 +135,48 @@ func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas messages = append(messages, msg) } - case part.FunctionCall != nil: - // Function call message - // Convert Args to JSON string if it's not already a string - argumentsStr := "" - if part.FunctionCall.Args != nil { - if argsBytes, err := json.Marshal(part.FunctionCall.Args); err == nil { - argumentsStr = string(argsBytes) + case part.FunctionCall != nil: + // Function call message + // Convert Args to JSON string if it's not already a string + argumentsStr := "" + if part.FunctionCall.Args != nil { + if argsBytes, err := json.Marshal(part.FunctionCall.Args); err == nil { + argumentsStr = string(argsBytes) + } } - } - // Create copies of the values to avoid range loop variable capture - functionCallID := part.FunctionCall.ID - functionCallName := part.FunctionCall.Name - - toolMsg := &schemas.ResponsesToolMessage{ - CallID: &functionCallID, - Name: &functionCallName, - Arguments: &argumentsStr, - } - - msg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Content: &schemas.ResponsesMessageContent{}, - Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), - ResponsesToolMessage: toolMsg, - } - messages = append(messages, msg) - - // Preserve thought signature if present (required for Gemini 3 Pro) - // Store it in a separate ResponsesReasoning message for better scalability - if len(part.ThoughtSignature) > 0 { - thoughtSig := string(part.ThoughtSignature) - reasoningMsg := schemas.ResponsesMessage{ - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), - Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), - ResponsesReasoning: &schemas.ResponsesReasoning{ - EncryptedContent: &thoughtSig, - }, - } - messages = append(messages, reasoningMsg) - } + // Create copies of the values to avoid range loop variable capture + functionCallID := part.FunctionCall.ID + functionCallName := part.FunctionCall.Name + + toolMsg := &schemas.ResponsesToolMessage{ + CallID: &functionCallID, + Name: &functionCallName, + Arguments: &argumentsStr, + } + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{}, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: toolMsg, + } + messages = append(messages, msg) + + // Preserve thought signature if present (required for Gemini 3 Pro) + // Store it in a separate ResponsesReasoning message for better scalability + if len(part.ThoughtSignature) > 0 { + thoughtSig := string(part.ThoughtSignature) + reasoningMsg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + EncryptedContent: &thoughtSig, + }, + } + messages = append(messages, reasoningMsg) + } case part.FunctionResponse != nil: // Function response message @@ -605,27 +606,27 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag } } - part := &Part{ - FunctionCall: &FunctionCall{ - Name: *msg.ResponsesToolMessage.Name, - Args: argsMap, - }, - } - if msg.ResponsesToolMessage.CallID != nil { - part.FunctionCall.ID = *msg.ResponsesToolMessage.CallID - } + part := &Part{ + FunctionCall: &FunctionCall{ + Name: *msg.ResponsesToolMessage.Name, + Args: argsMap, + }, + } + if msg.ResponsesToolMessage.CallID != nil { + part.FunctionCall.ID = *msg.ResponsesToolMessage.CallID + } - // Preserve thought signature from ResponsesReasoning message (required for Gemini 3 Pro) - // Look ahead to see if the next message is a reasoning message with encrypted content - if i+1 < len(messages) { - nextMsg := messages[i+1] - if nextMsg.Type != nil && *nextMsg.Type == schemas.ResponsesMessageTypeReasoning && - nextMsg.ResponsesReasoning != nil && nextMsg.ResponsesReasoning.EncryptedContent != nil { - part.ThoughtSignature = []byte(*nextMsg.ResponsesReasoning.EncryptedContent) + // Preserve thought signature from ResponsesReasoning message (required for Gemini 3 Pro) + // Look ahead to see if the next message is a reasoning message with encrypted content + if i+1 < len(messages) { + nextMsg := messages[i+1] + if nextMsg.Type != nil && *nextMsg.Type == schemas.ResponsesMessageTypeReasoning && + nextMsg.ResponsesReasoning != nil && nextMsg.ResponsesReasoning.EncryptedContent != nil { + part.ThoughtSignature = []byte(*nextMsg.ResponsesReasoning.EncryptedContent) + } } - } - content.Parts = append(content.Parts, part) + content.Parts = append(content.Parts, part) } case schemas.ResponsesMessageTypeFunctionCallOutput: // Convert function response to Gemini FunctionResponse diff --git a/core/providers/groq/groq_test.go b/core/providers/groq/groq_test.go index e57b7b517..f2d9acb5a 100644 --- a/core/providers/groq/groq_test.go +++ b/core/providers/groq/groq_test.go @@ -34,6 +34,7 @@ func TestGroq(t *testing.T) { {Provider: schemas.Groq, Model: "openai/gpt-oss-20b"}, }, EmbeddingModel: "", // Groq doesn't support embedding + ReasoningModel: "openai/gpt-oss-120b", Scenarios: testutil.TestScenarios{ TextCompletion: true, // Supported via chat completion conversion TextCompletionStream: true, // Supported via chat completion streaming conversion @@ -51,6 +52,7 @@ func TestGroq(t *testing.T) { CompleteEnd2End: true, Embedding: false, ListModels: true, + Reasoning: true, }, } diff --git a/core/providers/mistral/mistral_test.go b/core/providers/mistral/mistral_test.go index f1f9188db..275aeb23b 100644 --- a/core/providers/mistral/mistral_test.go +++ b/core/providers/mistral/mistral_test.go @@ -46,6 +46,7 @@ func TestMistral(t *testing.T) { CompleteEnd2End: true, Embedding: true, ListModels: false, + Reasoning: false, // Not supported right now because we are not using native mistral converters }, } diff --git a/core/providers/openai/chat.go b/core/providers/openai/chat.go index abfd07188..92f37748c 100644 --- a/core/providers/openai/chat.go +++ b/core/providers/openai/chat.go @@ -30,6 +30,11 @@ func ToOpenAIChatRequest(bifrostReq *schemas.BifrostChatRequest) *OpenAIChatRequ if bifrostReq.Params != nil { openaiReq.ChatParameters = *bifrostReq.Params + if openaiReq.ChatParameters.MaxCompletionTokens != nil && *openaiReq.ChatParameters.MaxCompletionTokens < MinMaxCompletionTokens { + openaiReq.ChatParameters.MaxCompletionTokens = schemas.Ptr(MinMaxCompletionTokens) + } + // Drop user field if it exceeds OpenAI's 64 character limit + openaiReq.ChatParameters.User = SanitizeUserField(openaiReq.ChatParameters.User) } switch bifrostReq.Provider { diff --git a/core/providers/openai/responses.go b/core/providers/openai/responses.go index c4246f5ea..7c074ac47 100644 --- a/core/providers/openai/responses.go +++ b/core/providers/openai/responses.go @@ -1,6 +1,10 @@ package openai -import "github.com/maximhq/bifrost/core/schemas" +import ( + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) // ToBifrostResponsesRequest converts an OpenAI responses request to Bifrost format func (request *OpenAIResponsesRequest) ToBifrostResponsesRequest() *schemas.BifrostResponsesRequest { @@ -34,20 +38,72 @@ func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Open if bifrostReq == nil || bifrostReq.Input == nil { return nil } - // Preparing final input - input := OpenAIResponsesRequestInput{ - OpenAIResponsesRequestInputArray: bifrostReq.Input, + + var messages []schemas.ResponsesMessage + // OpenAI models (except for gpt-oss) do not support reasoning content blocks, so we need to convert them to summaries, if there are any + messages = make([]schemas.ResponsesMessage, 0, len(bifrostReq.Input)) + for _, message := range bifrostReq.Input { + if message.ResponsesReasoning != nil { + // According to OpenAI's Responses API format specification, for non-gpt-oss models, a message + // with ResponsesReasoning != nil and non-empty Content.ContentBlocks but empty Summary and + // nil EncryptedContent represents a reasoning-only message that should be skipped, as these + // models do not support reasoning content blocks in the output. This constraint ensures + // compatibility with OpenAI's intended responses format behavior where reasoning-only messages + // without summaries are not included in the request payload for non-gpt-oss models. + if len(message.ResponsesReasoning.Summary) == 0 && + message.Content != nil && + len(message.Content.ContentBlocks) > 0 && + !strings.Contains(bifrostReq.Model, "gpt-oss") && + message.ResponsesReasoning.EncryptedContent == nil { + continue + } + + // If the message has summaries but no content blocks and the model is gpt-oss, then convert the summaries to content blocks + if len(message.ResponsesReasoning.Summary) > 0 && + strings.Contains(bifrostReq.Model, "gpt-oss") && + message.Content == nil { + var newMessage schemas.ResponsesMessage + newMessage.ID = message.ID + newMessage.Type = message.Type + newMessage.Status = message.Status + newMessage.Role = message.Role + + // Convert summaries to content blocks + contentBlocks := make([]schemas.ResponsesMessageContentBlock, 0, len(message.ResponsesReasoning.Summary)) + for _, summary := range message.ResponsesReasoning.Summary { + contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: schemas.Ptr(summary.Text), + }) + } + newMessage.Content = &schemas.ResponsesMessageContent{ + ContentBlocks: contentBlocks, + } + messages = append(messages, newMessage) + } else { + messages = append(messages, message) + } + } else { + messages = append(messages, message) + } } // Updating params params := bifrostReq.Params // Create the responses request with properly mapped parameters req := &OpenAIResponsesRequest{ Model: bifrostReq.Model, - Input: input, + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: messages, + }, } if params != nil { req.ResponsesParameters = *params + if req.ResponsesParameters.MaxOutputTokens != nil && *req.ResponsesParameters.MaxOutputTokens < MinMaxCompletionTokens { + req.ResponsesParameters.MaxOutputTokens = schemas.Ptr(MinMaxCompletionTokens) + } + // Drop user field if it exceeds OpenAI's 64 character limit + req.ResponsesParameters.User = SanitizeUserField(req.ResponsesParameters.User) // Filter out tools that OpenAI doesn't support req.filterUnsupportedTools() } diff --git a/core/providers/openai/responses_marshal_test.go b/core/providers/openai/responses_marshal_test.go new file mode 100644 index 000000000..e7f08f291 --- /dev/null +++ b/core/providers/openai/responses_marshal_test.go @@ -0,0 +1,481 @@ +package openai + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenAIResponsesRequest_MarshalJSON_ReasoningMaxTokensAbsent(t *testing.T) { + tests := []struct { + name string + request *OpenAIResponsesRequest + description string + }{ + { + name: "reasoning with MaxTokens set should omit max_tokens from output", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr("test input"), + }, + ResponsesParameters: schemas.ResponsesParameters{ + Reasoning: &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("high"), + MaxTokens: schemas.Ptr(1000), + Summary: schemas.Ptr("detailed"), + }, + }, + }, + description: "When Reasoning.MaxTokens is set, it should be absent from JSON output", + }, + { + name: "reasoning with all fields set should omit only max_tokens", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr("test"), + }, + ResponsesParameters: schemas.ResponsesParameters{ + Reasoning: &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("medium"), + GenerateSummary: schemas.Ptr("auto"), + Summary: schemas.Ptr("concise"), + MaxTokens: schemas.Ptr(500), + }, + }, + }, + description: "All reasoning fields except MaxTokens should be present in output", + }, + { + name: "reasoning with nil MaxTokens should not include max_tokens", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr("test"), + }, + ResponsesParameters: schemas.ResponsesParameters{ + Reasoning: &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("low"), + MaxTokens: nil, + }, + }, + }, + description: "When Reasoning.MaxTokens is nil, max_tokens should not appear in output", + }, + { + name: "nil reasoning should not include reasoning field", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr("test"), + }, + ResponsesParameters: schemas.ResponsesParameters{ + Reasoning: nil, + }, + }, + description: "When Reasoning is nil, reasoning field should not appear in output", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jsonBytes, err := tt.request.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal JSON: %v", err) + } + + // Parse the JSON to check structure + var jsonMap map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil { + t.Fatalf("Failed to unmarshal marshaled JSON: %v", err) + } + + // Check that reasoning.max_tokens is absent + if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok { + if maxTokens, exists := reasoning["max_tokens"]; exists { + t.Errorf("%s: reasoning.max_tokens should be absent from JSON output, but found: %v", tt.description, maxTokens) + } + + // Verify other reasoning fields are present when they should be + if tt.request.Reasoning != nil { + if tt.request.Reasoning.Effort != nil { + if _, exists := reasoning["effort"]; !exists { + t.Error("reasoning.effort should be present in output") + } + } + if tt.request.Reasoning.Summary != nil { + if _, exists := reasoning["summary"]; !exists { + t.Error("reasoning.summary should be present in output") + } + } + if tt.request.Reasoning.GenerateSummary != nil { + if _, exists := reasoning["generate_summary"]; !exists { + t.Error("reasoning.generate_summary should be present in output") + } + } + } + } else if tt.request.Reasoning != nil { + // If reasoning is set, it should appear in JSON (unless all fields are nil/omitted) + if tt.request.Reasoning.Effort != nil || tt.request.Reasoning.Summary != nil || tt.request.Reasoning.GenerateSummary != nil { + t.Error("reasoning field should be present in JSON when Reasoning is set with non-nil fields") + } + } + }) + } +} + +func TestOpenAIResponsesRequest_MarshalJSON_InputStringForm(t *testing.T) { + tests := []struct { + name string + request *OpenAIResponsesRequest + expected string + description string + }{ + { + name: "input as string is correctly marshaled", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr("Hello, world!"), + }, + }, + expected: "Hello, world!", + description: "Input field should be marshaled as a string when OpenAIResponsesRequestInputStr is set", + }, + { + name: "input as empty string is correctly marshaled", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr(""), + }, + }, + expected: "", + description: "Input field should be marshaled as empty string when set to empty string", + }, + { + name: "input as string with special characters", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr(`{"key": "value"}`), + }, + }, + expected: `{"key": "value"}`, + description: "Input field should correctly marshal strings with special characters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jsonBytes, err := tt.request.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal JSON: %v", err) + } + + // Parse the JSON to check input field + var jsonMap map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil { + t.Fatalf("Failed to unmarshal marshaled JSON: %v", err) + } + + // Check that input is a string + inputValue, exists := jsonMap["input"] + if !exists { + t.Fatalf("%s: input field should be present in JSON", tt.description) + } + + inputStr, ok := inputValue.(string) + if !ok { + t.Errorf("%s: input field should be a string, got type %T", tt.description, inputValue) + } + + if inputStr != tt.expected { + t.Errorf("%s: expected input to be %q, got %q", tt.description, tt.expected, inputStr) + } + }) + } +} + +func TestOpenAIResponsesRequest_MarshalJSON_InputArrayForm(t *testing.T) { + tests := []struct { + name string + request *OpenAIResponsesRequest + validate func(t *testing.T, inputValue interface{}) + description string + }{ + { + name: "input as array is correctly marshaled", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + }, + }, + validate: func(t *testing.T, inputValue interface{}) { + inputArray, ok := inputValue.([]interface{}) + if !ok { + t.Fatalf("Expected input to be an array, got type %T", inputValue) + } + if len(inputArray) != 1 { + t.Errorf("Expected 1 message in array, got %d", len(inputArray)) + } + }, + description: "Input field should be marshaled as an array when OpenAIResponsesRequestInputArray is set", + }, + { + name: "input as empty array is correctly marshaled", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{}, + }, + }, + validate: func(t *testing.T, inputValue interface{}) { + inputArray, ok := inputValue.([]interface{}) + if !ok { + t.Fatalf("Expected input to be an array, got type %T", inputValue) + } + if len(inputArray) != 0 { + t.Errorf("Expected empty array, got %d elements", len(inputArray)) + } + }, + description: "Input field should be marshaled as empty array when set to empty array", + }, + { + name: "input as array with multiple messages", + request: &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("You are a helpful assistant."), + }, + }, + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("What is 2+2?"), + }, + }, + }, + }, + }, + validate: func(t *testing.T, inputValue interface{}) { + inputArray, ok := inputValue.([]interface{}) + if !ok { + t.Fatalf("Expected input to be an array, got type %T", inputValue) + } + if len(inputArray) != 2 { + t.Errorf("Expected 2 messages in array, got %d", len(inputArray)) + } + }, + description: "Input field should correctly marshal arrays with multiple messages", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jsonBytes, err := tt.request.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal JSON: %v", err) + } + + // Parse the JSON to check input field + var jsonMap map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil { + t.Fatalf("Failed to unmarshal marshaled JSON: %v", err) + } + + // Check that input is present + inputValue, exists := jsonMap["input"] + if !exists { + t.Fatalf("%s: input field should be present in JSON", tt.description) + } + + // Validate using the provided function + tt.validate(t, inputValue) + }) + } +} + +func TestOpenAIResponsesRequest_MarshalJSON_FieldShadowingBehavior(t *testing.T) { + // This test verifies that the field shadowing pattern works correctly + // by ensuring that the aux struct properly shadows Input and Reasoning fields + t.Run("field shadowing preserves other fields", func(t *testing.T) { + request := &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputStr: schemas.Ptr("test input"), + }, + ResponsesParameters: schemas.ResponsesParameters{ + MaxOutputTokens: schemas.Ptr(100), + Temperature: schemas.Ptr(0.7), + Reasoning: &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("high"), + MaxTokens: schemas.Ptr(500), // This should be omitted + Summary: schemas.Ptr("detailed"), + }, + }, + Stream: schemas.Ptr(true), + Fallbacks: []string{"fallback1", "fallback2"}, + } + + jsonBytes, err := request.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal JSON: %v", err) + } + + var jsonMap map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil { + t.Fatalf("Failed to unmarshal marshaled JSON: %v", err) + } + + // Verify base fields are present + if jsonMap["model"] != "gpt-4o" { + t.Errorf("Expected model to be 'gpt-4o', got %v", jsonMap["model"]) + } + + if jsonMap["stream"] != true { + t.Errorf("Expected stream to be true, got %v", jsonMap["stream"]) + } + + fallbacks, ok := jsonMap["fallbacks"].([]interface{}) + if !ok || len(fallbacks) != 2 { + t.Errorf("Expected fallbacks to have 2 elements, got %v", jsonMap["fallbacks"]) + } + + // Verify ResponsesParameters fields are present + if jsonMap["max_output_tokens"] != float64(100) { + t.Errorf("Expected max_output_tokens to be 100, got %v", jsonMap["max_output_tokens"]) + } + + if jsonMap["temperature"] != 0.7 { + t.Errorf("Expected temperature to be 0.7, got %v", jsonMap["temperature"]) + } + + // Verify reasoning.max_tokens is absent + if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok { + if _, exists := reasoning["max_tokens"]; exists { + t.Error("reasoning.max_tokens should be absent from JSON output") + } + if reasoning["effort"] != "high" { + t.Errorf("Expected reasoning.effort to be 'high', got %v", reasoning["effort"]) + } + if reasoning["summary"] != "detailed" { + t.Errorf("Expected reasoning.summary to be 'detailed', got %v", reasoning["summary"]) + } + } else { + t.Error("reasoning field should be present in JSON") + } + + // Verify input is correctly marshaled + if jsonMap["input"] != "test input" { + t.Errorf("Expected input to be 'test input', got %v", jsonMap["input"]) + } + }) +} + +func TestOpenAIResponsesRequest_MarshalJSON_RoundTrip(t *testing.T) { + // Test that marshaling and unmarshaling preserves all fields except reasoning.max_tokens + t.Run("round trip preserves fields except reasoning.max_tokens", func(t *testing.T) { + original := &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + }, + ResponsesParameters: schemas.ResponsesParameters{ + MaxOutputTokens: schemas.Ptr(200), + Temperature: schemas.Ptr(0.8), + Reasoning: &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr("medium"), + MaxTokens: schemas.Ptr(1000), // Should be omitted + Summary: schemas.Ptr("auto"), + }, + }, + Stream: schemas.Ptr(false), + } + + // Marshal + jsonBytes, err := original.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Verify reasoning.max_tokens is absent in the JSON string + jsonStr := string(jsonBytes) + if strings.Contains(jsonStr, `"max_tokens"`) { + // Check if it's inside reasoning object + if strings.Contains(jsonStr, `"reasoning"`) { + // Parse to verify it's not in reasoning + var jsonMap map[string]interface{} + if err := json.Unmarshal(jsonBytes, &jsonMap); err == nil { + if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok { + if _, exists := reasoning["max_tokens"]; exists { + t.Error("reasoning.max_tokens should not be present in marshaled JSON") + } + } + } + } + } + + // Unmarshal back + var unmarshaled OpenAIResponsesRequest + if err := sonic.Unmarshal(jsonBytes, &unmarshaled); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Verify fields are preserved + if unmarshaled.Model != original.Model { + t.Errorf("Model not preserved: expected %q, got %q", original.Model, unmarshaled.Model) + } + + if unmarshaled.Stream == nil || *unmarshaled.Stream != *original.Stream { + t.Error("Stream not preserved") + } + + if unmarshaled.MaxOutputTokens == nil || *unmarshaled.MaxOutputTokens != *original.MaxOutputTokens { + t.Error("MaxOutputTokens not preserved") + } + + if unmarshaled.Temperature == nil || *unmarshaled.Temperature != *original.Temperature { + t.Error("Temperature not preserved") + } + + // Verify reasoning fields except MaxTokens + if unmarshaled.Reasoning == nil { + t.Fatal("Reasoning should be present") + } + if unmarshaled.Reasoning.Effort == nil || *unmarshaled.Reasoning.Effort != *original.Reasoning.Effort { + t.Error("Reasoning.Effort not preserved") + } + if unmarshaled.Reasoning.Summary == nil || *unmarshaled.Reasoning.Summary != *original.Reasoning.Summary { + t.Error("Reasoning.Summary not preserved") + } + // MaxTokens should be nil after unmarshaling (since it wasn't in JSON) + if unmarshaled.Reasoning.MaxTokens != nil { + t.Error("Reasoning.MaxTokens should be nil after unmarshaling (was omitted from JSON)") + } + }) +} + diff --git a/core/providers/openai/responses_test.go b/core/providers/openai/responses_test.go new file mode 100644 index 000000000..764bd1c1b --- /dev/null +++ b/core/providers/openai/responses_test.go @@ -0,0 +1,346 @@ +package openai + +import ( + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestToOpenAIResponsesRequest_ReasoningOnlyMessageSkip(t *testing.T) { + tests := []struct { + name string + model string + message schemas.ResponsesMessage + expectedIncluded bool + description string + }{ + { + name: "reasoning-only message skipped for non-gpt-oss model", + model: "gpt-4o", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, // empty Summary + EncryptedContent: nil, // nil EncryptedContent + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: schemas.Ptr("reasoning text"), + }, + }, // non-empty ContentBlocks + }, + }, + expectedIncluded: false, + description: "Message with ResponsesReasoning != nil, empty Summary, non-empty ContentBlocks, non-gpt-oss model, and nil EncryptedContent should be skipped", + }, + { + name: "message with Summary preserved for non-gpt-oss model", + model: "gpt-4o", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{ + { + Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, + Text: "summary text", + }, + }, // non-empty Summary + EncryptedContent: nil, + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: schemas.Ptr("reasoning text"), + }, + }, + }, + }, + expectedIncluded: true, + description: "Message with non-empty Summary should be preserved even if it has ContentBlocks", + }, + { + name: "message with EncryptedContent preserved for non-gpt-oss model", + model: "gpt-4o", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, // empty Summary + EncryptedContent: schemas.Ptr("encrypted"), // non-nil EncryptedContent + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: schemas.Ptr("reasoning text"), + }, + }, + }, + }, + expectedIncluded: true, + description: "Message with non-nil EncryptedContent should be preserved even if Summary is empty", + }, + { + name: "message with empty ContentBlocks preserved for non-gpt-oss model", + model: "gpt-4o", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, // empty Summary + EncryptedContent: nil, + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{}, // empty ContentBlocks + }, + }, + expectedIncluded: true, + description: "Message with empty ContentBlocks should be preserved", + }, + { + name: "message with nil Content preserved for non-gpt-oss model", + model: "gpt-4o", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, // empty Summary + EncryptedContent: nil, + }, + Content: nil, // nil Content + }, + expectedIncluded: true, + description: "Message with nil Content should be preserved", + }, + { + name: "reasoning-only message preserved for gpt-oss model", + model: "gpt-oss", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, // empty Summary + EncryptedContent: nil, + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: schemas.Ptr("reasoning text"), + }, + }, + }, + }, + expectedIncluded: true, + description: "Message with reasoning-only content should be preserved for gpt-oss model", + }, + { + name: "message without ResponsesReasoning preserved", + model: "gpt-4o", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: schemas.Ptr("regular text"), + }, + }, + }, + }, + expectedIncluded: true, + description: "Message without ResponsesReasoning should always be preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bifrostReq := &schemas.BifrostResponsesRequest{ + Model: tt.model, + Input: []schemas.ResponsesMessage{tt.message}, + } + + result := ToOpenAIResponsesRequest(bifrostReq) + + if result == nil { + t.Fatal("ToOpenAIResponsesRequest returned nil") + } + + messageCount := len(result.Input.OpenAIResponsesRequestInputArray) + isIncluded := messageCount > 0 + + if isIncluded != tt.expectedIncluded { + t.Errorf("%s: expected message to be included=%v (messageCount=%d), got included=%v (messageCount=%d)", + tt.description, tt.expectedIncluded, func() int { + if tt.expectedIncluded { + return 1 + } + return 0 + }(), isIncluded, messageCount) + } + + // If message should be included, verify it's actually present + if tt.expectedIncluded && messageCount == 0 { + t.Error("Expected message to be included but result array is empty") + } + + // If message should be excluded, verify it's not present + if !tt.expectedIncluded && messageCount > 0 { + t.Errorf("Expected message to be excluded but found %d message(s) in result", messageCount) + } + }) + } +} + +func TestToOpenAIResponsesRequest_GPTOSS_SummaryToContentBlocks(t *testing.T) { + tests := []struct { + name string + model string + message schemas.ResponsesMessage + expectedBlocks int + expectedBlockText string + description string + }{ + { + name: "gpt-oss converts Summary to ContentBlocks", + model: "gpt-oss", + message: schemas.ResponsesMessage{ + ID: schemas.Ptr("msg-1"), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Status: schemas.Ptr("completed"), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{ + { + Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, + Text: "First summary", + }, + { + Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, + Text: "Second summary", + }, + }, + EncryptedContent: nil, + }, + Content: nil, // No ContentBlocks initially + }, + expectedBlocks: 2, + expectedBlockText: "First summary", + description: "gpt-oss model should convert Summary to ContentBlocks when Content is nil", + }, + { + name: "gpt-oss preserves message when Content already exists", + model: "gpt-oss", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{ + { + Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, + Text: "summary text", + }, + }, + }, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: schemas.Ptr("existing content"), + }, + }, + }, + }, + expectedBlocks: 1, + expectedBlockText: "existing content", + description: "gpt-oss model should preserve message when Content already exists", + }, + { + name: "gpt-oss variant model converts Summary to ContentBlocks", + model: "provider/gpt-oss-variant", + message: schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{ + { + Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, + Text: "variant summary", + }, + }, + }, + Content: nil, + }, + expectedBlocks: 1, + expectedBlockText: "variant summary", + description: "gpt-oss variant model should also convert Summary to ContentBlocks", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bifrostReq := &schemas.BifrostResponsesRequest{ + Model: tt.model, + Input: []schemas.ResponsesMessage{tt.message}, + } + + result := ToOpenAIResponsesRequest(bifrostReq) + + if result == nil { + t.Fatal("ToOpenAIResponsesRequest returned nil") + } + + if len(result.Input.OpenAIResponsesRequestInputArray) != 1 { + t.Fatalf("Expected 1 message, got %d", len(result.Input.OpenAIResponsesRequestInputArray)) + } + + resultMsg := result.Input.OpenAIResponsesRequestInputArray[0] + + // Check if Summary was converted to ContentBlocks for gpt-oss + if strings.Contains(tt.model, "gpt-oss") && len(tt.message.ResponsesReasoning.Summary) > 0 && tt.message.Content == nil { + if resultMsg.Content == nil { + t.Fatal("Expected Content to be created from Summary") + } + + if len(resultMsg.Content.ContentBlocks) != tt.expectedBlocks { + t.Errorf("Expected %d ContentBlocks, got %d", tt.expectedBlocks, len(resultMsg.Content.ContentBlocks)) + } + + if len(resultMsg.Content.ContentBlocks) > 0 { + firstBlock := resultMsg.Content.ContentBlocks[0] + if firstBlock.Type != schemas.ResponsesOutputMessageContentTypeReasoning { + t.Errorf("Expected ContentBlock type to be reasoning_text, got %s", firstBlock.Type) + } + + if firstBlock.Text == nil || *firstBlock.Text != tt.expectedBlockText { + t.Errorf("Expected first ContentBlock text to be %q, got %q", tt.expectedBlockText, func() string { + if firstBlock.Text == nil { + return "" + } + return *firstBlock.Text + }()) + } + } + + // Verify that original message fields are preserved + if tt.message.ID != nil && (resultMsg.ID == nil || *resultMsg.ID != *tt.message.ID) { + t.Errorf("Expected ID to be preserved") + } + if tt.message.Type != nil && (resultMsg.Type == nil || *resultMsg.Type != *tt.message.Type) { + t.Errorf("Expected Type to be preserved") + } + if tt.message.Status != nil && (resultMsg.Status == nil || *resultMsg.Status != *tt.message.Status) { + t.Errorf("Expected Status to be preserved") + } + if tt.message.Role != nil && (resultMsg.Role == nil || *resultMsg.Role != *tt.message.Role) { + t.Errorf("Expected Role to be preserved") + } + } else { + // For other cases, verify message is preserved as-is + if resultMsg.Content != nil && len(resultMsg.Content.ContentBlocks) > 0 { + if resultMsg.Content.ContentBlocks[0].Text == nil || *resultMsg.Content.ContentBlocks[0].Text != tt.expectedBlockText { + t.Errorf("Expected ContentBlock text to be preserved as %q", tt.expectedBlockText) + } + } + } + }) + } +} diff --git a/core/providers/openai/text.go b/core/providers/openai/text.go index 4663d9ab0..f0154034a 100644 --- a/core/providers/openai/text.go +++ b/core/providers/openai/text.go @@ -16,6 +16,8 @@ func ToOpenAITextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequ } if params != nil { openaiReq.TextCompletionParameters = *params + // Drop user field if it exceeds OpenAI's 64 character limit + openaiReq.TextCompletionParameters.User = SanitizeUserField(openaiReq.TextCompletionParameters.User) } return openaiReq } diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 352006e82..2cab76756 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -1,12 +1,17 @@ package openai import ( + "encoding/json" "fmt" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" ) +const ( + MinMaxCompletionTokens = 16 +) + // REQUEST TYPES // OpenAITextCompletionRequest represents an OpenAI text completion request @@ -103,6 +108,39 @@ func (r *OpenAIChatRequest) MarshalJSON() ([]byte, error) { return sonic.Marshal(aux) } +// UnmarshalJSON implements custom JSON unmarshalling for OpenAIChatRequest. +// This is needed because ChatParameters has a custom UnmarshalJSON method, +// which would otherwise "hijack" the unmarshalling and ignore the other fields +// (Model, Messages, Stream, MaxTokens, Fallbacks). +func (r *OpenAIChatRequest) UnmarshalJSON(data []byte) error { + // Unmarshal the request-specific fields directly + type baseFields struct { + Model string `json:"model"` + Messages []OpenAIMessage `json:"messages"` + Stream *bool `json:"stream,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` + } + var base baseFields + if err := sonic.Unmarshal(data, &base); err != nil { + return err + } + r.Model = base.Model + r.Messages = base.Messages + r.Stream = base.Stream + r.MaxTokens = base.MaxTokens + r.Fallbacks = base.Fallbacks + + // Unmarshal ChatParameters (which has its own custom unmarshaller) + var params schemas.ChatParameters + if err := sonic.Unmarshal(data, ¶ms); err != nil { + return err + } + r.ChatParameters = params + + return nil +} + // IsStreamingRequested implements the StreamingRequest interface func (r *OpenAIChatRequest) IsStreamingRequested() bool { return r.Stream != nil && *r.Stream @@ -153,6 +191,46 @@ type OpenAIResponsesRequest struct { Fallbacks []string `json:"fallbacks,omitempty"` } +// MarshalJSON implements custom JSON marshalling for OpenAIResponsesRequest. +// It sets parameters.reasoning.max_tokens to nil before marshaling. +func (r *OpenAIResponsesRequest) MarshalJSON() ([]byte, error) { + type Alias OpenAIResponsesRequest + + // Manually marshal Input using its custom MarshalJSON method + inputBytes, err := r.Input.MarshalJSON() + if err != nil { + return nil, err + } + + // Aux struct: + // - Alias embeds all original fields + // - Input shadows the embedded Input field and uses json.RawMessage to preserve custom marshaling + // - Reasoning shadows the embedded ResponsesParameters.Reasoning + // so that we can modify max_tokens before marshaling + aux := struct { + *Alias + // Shadow the embedded "input" field to use custom marshaling + Input json.RawMessage `json:"input"` + // Shadow the embedded "reasoning" field to modify it + Reasoning *schemas.ResponsesParametersReasoning `json:"reasoning,omitempty"` + }{ + Alias: (*Alias)(r), + Input: json.RawMessage(inputBytes), + } + + // Copy reasoning but set MaxTokens to nil + if r.Reasoning != nil { + aux.Reasoning = &schemas.ResponsesParametersReasoning{ + Effort: r.Reasoning.Effort, + GenerateSummary: r.Reasoning.GenerateSummary, + Summary: r.Reasoning.Summary, + MaxTokens: nil, // Always set to nil + } + } + + return sonic.Marshal(aux) +} + // IsStreamingRequested implements the StreamingRequest interface func (r *OpenAIResponsesRequest) IsStreamingRequested() bool { return r.Stream != nil && *r.Stream diff --git a/core/providers/openai/types_test.go b/core/providers/openai/types_test.go new file mode 100644 index 000000000..7e5dc5f07 --- /dev/null +++ b/core/providers/openai/types_test.go @@ -0,0 +1,463 @@ +package openai + +import ( + "testing" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenAIChatRequest_UnmarshalJSON_BaseFieldsPreserved(t *testing.T) { + tests := []struct { + name string + jsonPayload string + validate func(t *testing.T, req *OpenAIChatRequest) + }{ + { + name: "all base fields preserved with ChatParameters", + jsonPayload: `{ + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "Hello, world!" + } + ], + "stream": true, + "max_tokens": 100, + "fallbacks": ["gpt-3.5-turbo"], + "temperature": 0.7, + "top_p": 0.9 + }`, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // Assert base fields are preserved + if req.Model != "gpt-4o" { + t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model) + } + + if len(req.Messages) != 1 { + t.Fatalf("Expected 1 message, got %d", len(req.Messages)) + } + if req.Messages[0].Role != schemas.ChatMessageRoleUser { + t.Errorf("Expected message role to be 'user', got %q", req.Messages[0].Role) + } + if req.Messages[0].Content == nil || req.Messages[0].Content.ContentStr == nil { + t.Fatal("Expected message content to be set") + } + if *req.Messages[0].Content.ContentStr != "Hello, world!" { + t.Errorf("Expected message content to be 'Hello, world!', got %q", *req.Messages[0].Content.ContentStr) + } + + if req.Stream == nil || !*req.Stream { + t.Error("Expected Stream to be true") + } + + if req.MaxTokens == nil || *req.MaxTokens != 100 { + t.Errorf("Expected MaxTokens to be 100, got %v", req.MaxTokens) + } + + if len(req.Fallbacks) != 1 || req.Fallbacks[0] != "gpt-3.5-turbo" { + t.Errorf("Expected Fallbacks to be ['gpt-3.5-turbo'], got %v", req.Fallbacks) + } + + // Assert ChatParameters fields are populated + if req.Temperature == nil || *req.Temperature != 0.7 { + t.Errorf("Expected Temperature to be 0.7, got %v", req.Temperature) + } + + if req.TopP == nil || *req.TopP != 0.9 { + t.Errorf("Expected TopP to be 0.9, got %v", req.TopP) + } + }, + }, + { + name: "base fields with multiple ChatParameters fields", + jsonPayload: `{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is 2+2?" + } + ], + "stream": false, + "max_tokens": 500, + "fallbacks": ["gpt-4o", "gpt-4"], + "temperature": 0.5, + "top_p": 0.95, + "frequency_penalty": 0.2, + "presence_penalty": 0.3, + "seed": 42, + "stop": ["STOP", "END"] + }`, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // Assert base fields + if req.Model != "gpt-3.5-turbo" { + t.Errorf("Expected Model to be 'gpt-3.5-turbo', got %q", req.Model) + } + + if len(req.Messages) != 2 { + t.Fatalf("Expected 2 messages, got %d", len(req.Messages)) + } + + if req.Stream == nil || *req.Stream { + t.Error("Expected Stream to be false") + } + + if req.MaxTokens == nil || *req.MaxTokens != 500 { + t.Errorf("Expected MaxTokens to be 500, got %v", req.MaxTokens) + } + + if len(req.Fallbacks) != 2 { + t.Errorf("Expected 2 fallbacks, got %d", len(req.Fallbacks)) + } + + // Assert multiple ChatParameters fields + if req.Temperature == nil || *req.Temperature != 0.5 { + t.Errorf("Expected Temperature to be 0.5, got %v", req.Temperature) + } + + if req.TopP == nil || *req.TopP != 0.95 { + t.Errorf("Expected TopP to be 0.95, got %v", req.TopP) + } + + if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.2 { + t.Errorf("Expected FrequencyPenalty to be 0.2, got %v", req.FrequencyPenalty) + } + + if req.PresencePenalty == nil || *req.PresencePenalty != 0.3 { + t.Errorf("Expected PresencePenalty to be 0.3, got %v", req.PresencePenalty) + } + + if req.Seed == nil || *req.Seed != 42 { + t.Errorf("Expected Seed to be 42, got %v", req.Seed) + } + + if len(req.Stop) != 2 { + t.Errorf("Expected Stop to have 2 elements, got %d", len(req.Stop)) + } + }, + }, + { + name: "base fields with optional fields omitted", + jsonPayload: `{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Test" + } + ], + "temperature": 1.0, + "top_p": 1.0 + }`, + validate: func(t *testing.T, req *OpenAIChatRequest) { + if req.Model != "gpt-4" { + t.Errorf("Expected Model to be 'gpt-4', got %q", req.Model) + } + + if len(req.Messages) != 1 { + t.Fatalf("Expected 1 message, got %d", len(req.Messages)) + } + + // Optional fields should be nil/empty when omitted + if req.Stream != nil { + t.Error("Expected Stream to be nil when omitted") + } + + if req.MaxTokens != nil { + t.Error("Expected MaxTokens to be nil when omitted") + } + + if len(req.Fallbacks) != 0 { + t.Errorf("Expected Fallbacks to be empty when omitted, got %v", req.Fallbacks) + } + + // ChatParameters fields should still be populated + if req.Temperature == nil || *req.Temperature != 1.0 { + t.Errorf("Expected Temperature to be 1.0, got %v", req.Temperature) + } + + if req.TopP == nil || *req.TopP != 1.0 { + t.Errorf("Expected TopP to be 1.0, got %v", req.TopP) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req OpenAIChatRequest + if err := sonic.Unmarshal([]byte(tt.jsonPayload), &req); err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + + tt.validate(t, &req) + }) + } +} + +func TestOpenAIChatRequest_UnmarshalJSON_ChatParametersCustomLogic(t *testing.T) { + tests := []struct { + name string + jsonPayload string + validate func(t *testing.T, req *OpenAIChatRequest) + expectError bool + }{ + { + name: "reasoning_effort converted to Reasoning.Effort", + jsonPayload: `{ + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "Think step by step" + } + ], + "reasoning_effort": "high", + "temperature": 0.8 + }`, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // Assert base fields are preserved + if req.Model != "gpt-4o" { + t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model) + } + + // Assert reasoning_effort was converted to Reasoning.Effort + if req.Reasoning == nil { + t.Fatal("Expected Reasoning to be set from reasoning_effort") + } + if req.Reasoning.Effort == nil { + t.Fatal("Expected Reasoning.Effort to be set") + } + if *req.Reasoning.Effort != "high" { + t.Errorf("Expected Reasoning.Effort to be 'high', got %q", *req.Reasoning.Effort) + } + + // Assert other ChatParameters fields are still populated + if req.Temperature == nil || *req.Temperature != 0.8 { + t.Errorf("Expected Temperature to be 0.8, got %v", req.Temperature) + } + }, + expectError: false, + }, + { + name: "both reasoning and reasoning_effort should error", + jsonPayload: `{ + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "Test" + } + ], + "reasoning": { + "effort": "medium" + }, + "reasoning_effort": "high" + }`, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // This should have failed during unmarshaling + }, + expectError: true, + }, + { + name: "reasoning_effort with multiple ChatParameters fields", + jsonPayload: `{ + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "Analyze this" + } + ], + "reasoning_effort": "medium", + "temperature": 0.6, + "top_p": 0.85, + "max_completion_tokens": 2000 + }`, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // Assert base fields + if req.Model != "gpt-4o" { + t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model) + } + + // Assert reasoning_effort conversion + if req.Reasoning == nil || req.Reasoning.Effort == nil { + t.Fatal("Expected Reasoning.Effort to be set from reasoning_effort") + } + if *req.Reasoning.Effort != "medium" { + t.Errorf("Expected Reasoning.Effort to be 'medium', got %q", *req.Reasoning.Effort) + } + + // Assert other ChatParameters fields + if req.Temperature == nil || *req.Temperature != 0.6 { + t.Errorf("Expected Temperature to be 0.6, got %v", req.Temperature) + } + if req.TopP == nil || *req.TopP != 0.85 { + t.Errorf("Expected TopP to be 0.85, got %v", req.TopP) + } + if req.MaxCompletionTokens == nil || *req.MaxCompletionTokens != 2000 { + t.Errorf("Expected MaxCompletionTokens to be 2000, got %v", req.MaxCompletionTokens) + } + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req OpenAIChatRequest + err := sonic.Unmarshal([]byte(tt.jsonPayload), &req) + + if tt.expectError { + if err == nil { + t.Error("Expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error during unmarshaling: %v", err) + } + + tt.validate(t, &req) + }) + } +} + +func TestOpenAIChatRequest_UnmarshalJSON_PresenceAssertions(t *testing.T) { + // Test that verifies presence of fields (not just values) + jsonPayload := `{ + "model": "gpt-4o-mini", + "messages": [ + { + "role": "assistant", + "content": "Hello!" + } + ], + "stream": false, + "max_tokens": 150, + "fallbacks": ["model1", "model2"], + "temperature": 0.3, + "top_p": 0.7, + "user": "test-user-123" + }` + + var req OpenAIChatRequest + if err := sonic.Unmarshal([]byte(jsonPayload), &req); err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + + // Presence assertions for base fields + if req.Model == "" { + t.Error("Model field should be present") + } + + if len(req.Messages) == 0 { + t.Error("Messages field should be present and non-empty") + } + + if req.Stream == nil { + t.Error("Stream field should be present (even if false)") + } + + if req.MaxTokens == nil { + t.Error("MaxTokens field should be present") + } + + if len(req.Fallbacks) == 0 { + t.Error("Fallbacks field should be present and non-empty") + } + + // Presence assertions for ChatParameters fields + if req.Temperature == nil { + t.Error("Temperature field should be present") + } + + if req.TopP == nil { + t.Error("TopP field should be present") + } + + if req.User == nil { + t.Error("User field should be present") + } +} + +func TestOpenAIChatRequest_UnmarshalJSON_ValueAssertions(t *testing.T) { + // Test that verifies exact values match expectations + jsonPayload := `{ + "model": "gpt-4-turbo", + "messages": [ + { + "role": "system", + "content": "System message" + }, + { + "role": "user", + "content": "User message" + } + ], + "stream": true, + "max_tokens": 250, + "fallbacks": ["fallback1"], + "temperature": 0.9, + "top_p": 0.95, + "seed": 12345, + "stop": ["END", "STOP"] + }` + + var req OpenAIChatRequest + if err := sonic.Unmarshal([]byte(jsonPayload), &req); err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + + // Value assertions for base fields + if req.Model != "gpt-4-turbo" { + t.Errorf("Expected Model value 'gpt-4-turbo', got %q", req.Model) + } + + if len(req.Messages) != 2 { + t.Fatalf("Expected 2 messages, got %d", len(req.Messages)) + } + if req.Messages[0].Role != schemas.ChatMessageRoleSystem { + t.Errorf("Expected first message role 'system', got %q", req.Messages[0].Role) + } + if req.Messages[1].Role != schemas.ChatMessageRoleUser { + t.Errorf("Expected second message role 'user', got %q", req.Messages[1].Role) + } + + if req.Stream == nil || !*req.Stream { + t.Error("Expected Stream value to be true") + } + + if req.MaxTokens == nil || *req.MaxTokens != 250 { + t.Errorf("Expected MaxTokens value 250, got %v", req.MaxTokens) + } + + if len(req.Fallbacks) != 1 || req.Fallbacks[0] != "fallback1" { + t.Errorf("Expected Fallbacks value ['fallback1'], got %v", req.Fallbacks) + } + + // Value assertions for ChatParameters fields + if req.Temperature == nil || *req.Temperature != 0.9 { + t.Errorf("Expected Temperature value 0.9, got %v", req.Temperature) + } + + if req.TopP == nil || *req.TopP != 0.95 { + t.Errorf("Expected TopP value 0.95, got %v", req.TopP) + } + + if req.Seed == nil || *req.Seed != 12345 { + t.Errorf("Expected Seed value 12345, got %v", req.Seed) + } + + if len(req.Stop) != 2 || req.Stop[0] != "END" || req.Stop[1] != "STOP" { + t.Errorf("Expected Stop value ['END', 'STOP'], got %v", req.Stop) + } +} + diff --git a/core/providers/openai/utils.go b/core/providers/openai/utils.go index 71906f014..c7ef1ecec 100644 --- a/core/providers/openai/utils.go +++ b/core/providers/openai/utils.go @@ -43,3 +43,14 @@ func ConvertBifrostMessagesToOpenAIMessages(messages []schemas.ChatMessage) []Op } return openaiMessages } + +// OpenAI enforces a 64 character maximum on the user field +const MaxUserFieldLength = 64 + +// SanitizeUserField returns nil if user exceeds MaxUserFieldLength, otherwise returns the original value +func SanitizeUserField(user *string) *string { + if user != nil && len(*user) > MaxUserFieldLength { + return nil + } + return user +} diff --git a/core/providers/openrouter/openrouter_test.go b/core/providers/openrouter/openrouter_test.go index 82fdda03a..d210e3fe4 100644 --- a/core/providers/openrouter/openrouter_test.go +++ b/core/providers/openrouter/openrouter_test.go @@ -28,7 +28,7 @@ func TestOpenRouter(t *testing.T) { VisionModel: "openai/gpt-4o", TextModel: "google/gemini-2.5-flash", EmbeddingModel: "", - ReasoningModel: "openai/o1", + ReasoningModel: "openai/gpt-oss-120b", Scenarios: testutil.TestScenarios{ TextCompletion: true, SimpleChat: true, diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index ee80dcd8b..ba86f5cb6 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net/http" "net/textproto" "net/url" @@ -263,7 +264,7 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette if convertedBody == nil { return nil, NewBifrostOperationError("request body is not provided", nil, providerType) } - jsonBody, err := sonic.Marshal(convertedBody) + jsonBody, err := sonic.MarshalIndent(convertedBody, "", " ") if err != nil { return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType) } @@ -318,6 +319,20 @@ func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif statusCode := resp.StatusCode() body := append([]byte(nil), resp.Body()...) + // decode body + decodedBody, err := CheckAndDecodeBody(resp) + if err != nil { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: err.Error(), + }, + } + } + + body = decodedBody + if err := sonic.Unmarshal(body, errorResp); err != nil { rawResponse := body message := fmt.Sprintf("provider API error: %s", string(rawResponse)) @@ -1007,3 +1022,100 @@ func HandleMultipleListModelsRequests( return response, nil } + +// GetRandomString generates a random alphanumeric string of the given length. +func GetRandomString(length int) string { + if length <= 0 { + return "" + } + randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) + letters := []rune("abcdefghijklmnopqrstuvwxyz0123456789") + b := make([]rune, length) + for i := range b { + b[i] = letters[randomSource.Intn(len(letters))] + } + return string(b) +} + +// GetReasoningEffortFromBudgetTokens maps a reasoning token budget to OpenAI reasoning effort. +// Valid values: none, low, medium, high +func GetReasoningEffortFromBudgetTokens( + budgetTokens int, + minBudgetTokens int, + maxTokens int, +) string { + if budgetTokens <= 0 { + return "none" + } + + // Defensive defaults + if maxTokens <= 0 { + return "medium" + } + + // Normalize budget + if budgetTokens < minBudgetTokens { + budgetTokens = minBudgetTokens + } + if budgetTokens > maxTokens { + budgetTokens = maxTokens + } + + // Avoid division by zero + if maxTokens <= minBudgetTokens { + return "high" + } + + ratio := float64(budgetTokens-minBudgetTokens) / float64(maxTokens-minBudgetTokens) + + switch { + case ratio <= 0.25: + return "low" + case ratio <= 0.60: + return "medium" + default: + return "high" + } +} + +// GetBudgetTokensFromReasoningEffort converts OpenAI reasoning effort +// into a reasoning token budget. +// effort ∈ {"none", "minimal", "low", "medium", "high"} +func GetBudgetTokensFromReasoningEffort( + effort string, + minBudgetTokens int, + maxTokens int, +) (int, error) { + if effort == "none" { + return 0, nil + } + + if minBudgetTokens > maxTokens { + return 0, fmt.Errorf("max_tokens must be greater than %d for reasoning", minBudgetTokens) + } + + // Defensive defaults + if maxTokens <= minBudgetTokens { + return minBudgetTokens, nil + } + + var ratio float64 + + switch effort { + case "minimal": + ratio = 0.025 + case "low": + ratio = 0.15 + case "medium": + ratio = 0.425 + case "high": + ratio = 0.80 + default: + // Unknown effort → safe default + ratio = 0.425 + } + + budget := minBudgetTokens + int(ratio*float64(maxTokens-minBudgetTokens)) + + return budget, nil +} diff --git a/core/providers/vertex/errors.go b/core/providers/vertex/errors.go index 11c2b4f00..b8effae16 100644 --- a/core/providers/vertex/errors.go +++ b/core/providers/vertex/errors.go @@ -10,15 +10,21 @@ import ( func parseVertexError(providerName schemas.ModelProvider, resp *fasthttp.Response) *schemas.BifrostError { var openAIErr schemas.BifrostError var vertexErr []VertexError - if err := sonic.Unmarshal(resp.Body(), &openAIErr); err != nil || openAIErr.Error == nil { + + decodedBody, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + } + + if err := sonic.Unmarshal(decodedBody, &openAIErr); err != nil || openAIErr.Error == nil { // Try Vertex error format if OpenAI format fails or is incomplete - if err := sonic.Unmarshal(resp.Body(), &vertexErr); err != nil { + if err := sonic.Unmarshal(decodedBody, &vertexErr); err != nil { //try with single Vertex error format var vertexErr VertexError - if err := sonic.Unmarshal(resp.Body(), &vertexErr); err != nil { + if err := sonic.Unmarshal(decodedBody, &vertexErr); err != nil { // Try VertexValidationError format (validation errors from Mistral endpoint) var validationErr VertexValidationError - if err := sonic.Unmarshal(resp.Body(), &validationErr); err != nil { + if err := sonic.Unmarshal(decodedBody, &validationErr); err != nil { return providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } if len(validationErr.Detail) > 0 { diff --git a/core/providers/vertex/types.go b/core/providers/vertex/types.go index bbcc87740..1a5e3b3a8 100644 --- a/core/providers/vertex/types.go +++ b/core/providers/vertex/types.go @@ -74,7 +74,6 @@ type VertexAdvancedVoiceOptions struct { LowLatencyJourneySynthesis bool `json:"lowLatencyJourneySynthesis,omitempty"` } - // VertexEmbeddingInstance represents a single embedding instance in the request type VertexEmbeddingInstance struct { Content string `json:"content"` // The text to generate embeddings for diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go new file mode 100644 index 000000000..4cf77138c --- /dev/null +++ b/core/providers/vertex/utils.go @@ -0,0 +1,86 @@ +package vertex + +import ( + "context" + "fmt" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/providers/anthropic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" +) + +func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { + var jsonBody []byte + var err error + + // Check if raw request body should be used + if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody { + jsonBody = request.GetRawRequestBody() + // Unmarshal and check if model and region are present + var requestBody map[string]interface{} + if err := sonic.Unmarshal(jsonBody, &requestBody); err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, fmt.Errorf("failed to unmarshal request body: %w", err), providerName) + } + // Add max_tokens if not present + if _, exists := requestBody["max_tokens"]; !exists { + requestBody["max_tokens"] = anthropic.AnthropicDefaultMaxTokens + } + delete(requestBody, "model") + delete(requestBody, "region") + // Add anthropic_version if not present + if _, exists := requestBody["anthropic_version"]; !exists { + requestBody["anthropic_version"] = DefaultVertexAnthropicVersion + } + // Add stream if not present + if isStreaming { + requestBody["stream"] = true + } + jsonBody, err = sonic.Marshal(requestBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + } + } else { + // Convert request to Anthropic format + reqBody, err := anthropic.ToAnthropicResponsesRequest(request) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerName) + } + if reqBody == nil { + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + } + + // Set deployment as model + reqBody.Model = deployment + if isStreaming { + reqBody.Stream = schemas.Ptr(true) + } + + // Convert struct to map for Vertex API + reqBytes, err := sonic.Marshal(reqBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + } + + var requestBody map[string]interface{} + if err := sonic.Unmarshal(reqBytes, &requestBody); err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, fmt.Errorf("failed to unmarshal request body: %w", err), providerName) + } + + // Add anthropic_version if not present + if _, exists := requestBody["anthropic_version"]; !exists { + requestBody["anthropic_version"] = DefaultVertexAnthropicVersion + } + + // Remove fields not needed by Vertex API + delete(requestBody, "model") + delete(requestBody, "region") + + jsonBody, err = sonic.Marshal(requestBody) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + } + } + + return jsonBody, nil +} diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index d63b91308..636fe8648 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -701,38 +701,7 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, deployment := provider.getModelDeployment(key, request.Model) if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( - ctx, - request, - func() (any, error) { - //TODO: optimize this double Marshal - // Format messages for Vertex API - var requestBody map[string]interface{} - - // Use centralized Anthropic converter - reqBody, err := anthropic.ToAnthropicResponsesRequest(request) - if err != nil { - return nil, err - } - if reqBody != nil { - reqBody.Model = deployment - } - // Convert struct to map for Vertex API - reqBytes, err := sonic.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request body: %w", err) - } - if err := sonic.Unmarshal(reqBytes, &requestBody); err != nil { - return nil, fmt.Errorf("failed to unmarshal request body: %w", err) - } - if _, exists := requestBody["anthropic_version"]; !exists { - requestBody["anthropic_version"] = DefaultVertexAnthropicVersion - } - delete(requestBody, "model") - delete(requestBody, "region") - return requestBody, nil - }, - provider.GetProviderKey()) + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false) if bifrostErr != nil { return nil, bifrostErr } @@ -869,38 +838,7 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) } - // Use Anthropic-style streaming for Claude models - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( - ctx, - request, - func() (any, error) { - reqBody, err := anthropic.ToAnthropicResponsesRequest(request) - if err != nil { - return nil, err - } - if reqBody != nil { - reqBody.Model = deployment - reqBody.Stream = schemas.Ptr(true) - } - // Convert struct to map for Vertex API - reqBytes, err := sonic.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request body: %w", err) - } - var requestBody map[string]interface{} - if err := sonic.Unmarshal(reqBytes, &requestBody); err != nil { - return nil, fmt.Errorf("failed to unmarshal request body: %w", err) - } - - if _, exists := requestBody["anthropic_version"]; !exists { - requestBody["anthropic_version"] = DefaultVertexAnthropicVersion - } - - delete(requestBody, "model") - delete(requestBody, "region") - return requestBody, nil - }, - provider.GetProviderKey()) + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, true) if bifrostErr != nil { return nil, bifrostErr } @@ -943,7 +881,7 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun ctx, provider.client, url, - jsonData, + jsonBody, headers, provider.networkConfig.ExtraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), diff --git a/core/providers/vertex/vertex_test.go b/core/providers/vertex/vertex_test.go index 2e6feb31f..0ce8cc64e 100644 --- a/core/providers/vertex/vertex_test.go +++ b/core/providers/vertex/vertex_test.go @@ -28,7 +28,7 @@ func TestVertex(t *testing.T) { VisionModel: "google/gemini-2.0-flash-001", TextModel: "", // Vertex doesn't support text completion in newer models EmbeddingModel: "text-multilingual-embedding-002", - // ReasoningModel: "google/gemini-2.5-pro", + ReasoningModel: "claude-4.5-haiku", Scenarios: testutil.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 62b308589..7b587480a 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -117,8 +117,10 @@ const ( BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool + BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost) BifrostContextKeyStructuredOutputToolName BifrostContextKey = "bifrost-structured-output-tool-name" // string (to store the name of the structured output tool (set by bifrost)) + BifrostContextKeyUserAgent BifrostContextKey = "bifrost-user-agent" // string (set by bifrost) ) // NOTE: for custom plugin implementation dealing with streaming short circuit, diff --git a/core/schemas/mux.go b/core/schemas/mux.go index 14423dc0c..d4b10da53 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -896,7 +896,6 @@ func (cr *BifrostChatResponse) ToBifrostResponsesResponse() *BifrostResponsesRes responsesMessages := choice.ChatNonStreamResponseChoice.Message.ToResponsesMessages() outputMessages = append(outputMessages, responsesMessages...) } - // Note: Stream choices would need different handling if needed } if len(outputMessages) > 0 { diff --git a/core/schemas/responses.go b/core/schemas/responses.go index d9e7cef6c..3df6d9618 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -66,6 +66,7 @@ type BifrostResponsesResponse struct { ServiceTier *string `json:"service_tier,omitempty"` Status *string `json:"status,omitempty"` // completed, failed, in_progress, cancelled, queued, or incomplete StreamOptions *ResponsesStreamOptions `json:"stream_options,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` // Not in OpenAI's spec, but sent by other providers Store *bool `json:"store,omitempty"` Temperature *float64 `json:"temperature,omitempty"` Text *ResponsesTextConfig `json:"text,omitempty"` @@ -396,9 +397,10 @@ const ( // ResponsesMessageContentBlock represents different types of content (text, image, file, audio) // Only one of the content type fields should be set type ResponsesMessageContentBlock struct { - Type ResponsesMessageContentBlockType `json:"type"` - FileID *string `json:"file_id,omitempty"` // Reference to uploaded file - Text *string `json:"text,omitempty"` + Type ResponsesMessageContentBlockType `json:"type"` + FileID *string `json:"file_id,omitempty"` // Reference to uploaded file + Text *string `json:"text,omitempty"` + Signature *string `json:"signature,omitempty"` // Signature of the content (for reasoning) *ResponsesInputMessageContentBlockImage *ResponsesInputMessageContentBlockFile @@ -727,7 +729,7 @@ func (rf *ResponsesFunctionToolCallOutput) UnmarshalJSON(data []byte) error { // ResponsesReasoning represents a reasoning output type ResponsesReasoning struct { - Summary []ResponsesReasoningContent `json:"summary"` + Summary []ResponsesReasoningSummary `json:"summary"` EncryptedContent *string `json:"encrypted_content,omitempty"` } @@ -739,8 +741,8 @@ const ( ResponsesReasoningContentBlockTypeSummaryText ResponsesReasoningContentBlockType = "summary_text" ) -// ResponsesReasoningContent represents a reasoning content block -type ResponsesReasoningContent struct { +// ResponsesReasoningSummary represents a reasoning content block +type ResponsesReasoningSummary struct { Type ResponsesReasoningContentBlockType `json:"type"` Text string `json:"text"` } @@ -1435,8 +1437,9 @@ type BifrostResponsesStreamResponse struct { ItemID *string `json:"item_id,omitempty"` Part *ResponsesMessageContentBlock `json:"part,omitempty"` - Delta *string `json:"delta,omitempty"` - LogProbs []ResponsesOutputMessageContentTextLogProb `json:"logprobs,omitempty"` + Delta *string `json:"delta,omitempty"` + Signature *string `json:"signature,omitempty"` // Not in OpenAI's spec, but sent by other providers + LogProbs []ResponsesOutputMessageContentTextLogProb `json:"logprobs,omitempty"` Text *string `json:"text,omitempty"` // Full text of the output item, comes with event "response.output_text.done" diff --git a/core/schemas/utils.go b/core/schemas/utils.go index 7fc979ff7..a450b3ed5 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -1039,12 +1039,12 @@ func deepCopyResponsesMessageContentBlock(original ResponsesMessageContentBlock) return copy } -// IsAnthropicModel checks if the model is an Anthropic model in Vertex. +// IsAnthropicModel checks if the model is an Anthropic model. func IsAnthropicModel(model string) bool { return strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") } -// IsMistralModel checks if the model is a Mistral or Codestral model in Vertex. +// IsMistralModel checks if the model is a Mistral or Codestral model. func IsMistralModel(model string) bool { return strings.Contains(model, "mistral") || strings.Contains(model, "codestral") } diff --git a/docs/docs.json b/docs/docs.json index 97ae603b5..5e1bd31f7 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -70,13 +70,6 @@ } ] }, - { - "group": "Models Catalog", - "icon": "box", - "pages": [ - "models-catalog/list" - ] - }, { "group": "Provider Integrations", "icon": "plug", diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 4fbbd0535..17e5510a4 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -1675,7 +1675,6 @@ func (s *RDBConfigStore) UpdateBudgets(ctx context.Context, budgets []*tables.Ta } else { txDB = s.db } - s.logger.Debug("updating budgets: %+v", budgets) for _, b := range budgets { if err := txDB.WithContext(ctx).Save(b).Error; err != nil { return s.parseGormError(err) diff --git a/framework/streaming/chat.go b/framework/streaming/chat.go index ed1b8055f..185a45019 100644 --- a/framework/streaming/chat.go +++ b/framework/streaming/chat.go @@ -255,13 +255,6 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, chunk.Delta = copied chunk.FinishReason = choice.FinishReason } - if choice.TextCompletionResponseChoice != nil { - deltaCopy := choice.TextCompletionResponseChoice.Text - chunk.Delta = &schemas.ChatStreamResponseChoiceDelta{ - Content: deltaCopy, - } - chunk.FinishReason = choice.FinishReason - } } // Extract token usage if result.ChatResponse.Usage != nil && result.ChatResponse.Usage.TotalTokens > 0 { diff --git a/framework/streaming/responses.go b/framework/streaming/responses.go index dfff77f8b..ca418573c 100644 --- a/framework/streaming/responses.go +++ b/framework/streaming/responses.go @@ -494,6 +494,44 @@ func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*Re if resp.Delta != nil && len(messages) > 0 { a.appendFunctionArgumentsDeltaToResponsesMessage(&messages[len(messages)-1], *resp.Delta) } + + case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta: + // Create new reasoning message if none exists, or find existing reasoning message to append delta to + if (resp.Delta != nil || resp.Signature != nil) && resp.ItemID != nil { + var targetMessage *schemas.ResponsesMessage + + // Find the reasoning message by ItemID + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].ID != nil && *messages[i].ID == *resp.ItemID { + targetMessage = &messages[i] + break + } + } + + // If no message found, create a new reasoning message + if targetMessage == nil { + newMessage := schemas.ResponsesMessage{ + ID: resp.ItemID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesReasoning: &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + }, + } + messages = append(messages, newMessage) + targetMessage = &messages[len(messages)-1] + } + + // Handle text delta + if resp.Delta != nil { + a.appendReasoningDeltaToResponsesMessage(targetMessage, *resp.Delta, resp.ContentIndex) + } + + // Handle signature delta + if resp.Signature != nil { + a.appendReasoningSignatureToResponsesMessage(targetMessage, *resp.Signature, resp.ContentIndex) + } + } } } @@ -585,6 +623,109 @@ func (a *Accumulator) appendFunctionArgumentsDeltaToResponsesMessage(message *sc } } +// appendReasoningDeltaToResponsesMessage appends reasoning delta to a responses message +func (a *Accumulator) appendReasoningDeltaToResponsesMessage(message *schemas.ResponsesMessage, delta string, contentIndex *int) { + // Handle reasoning content in two ways: + // 1. Content blocks (for reasoning_text content blocks) + // 2. ResponsesReasoning.Summary (for reasoning summary accumulation) + + // If we have a content index, this is reasoning content in content blocks + if contentIndex != nil { + if message.Content == nil { + message.Content = &schemas.ResponsesMessageContent{} + } + + // If we don't have content blocks yet, create them + if message.Content.ContentBlocks == nil { + message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, *contentIndex+1) + } + + // Ensure we have enough content blocks + for len(message.Content.ContentBlocks) <= *contentIndex { + message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{}) + } + + // Initialize the content block if needed + if message.Content.ContentBlocks[*contentIndex].Type == "" { + message.Content.ContentBlocks[*contentIndex].Type = schemas.ResponsesOutputMessageContentTypeReasoning + } + + // Append to existing reasoning text or create new text + if message.Content.ContentBlocks[*contentIndex].Text == nil { + message.Content.ContentBlocks[*contentIndex].Text = &delta + } else { + *message.Content.ContentBlocks[*contentIndex].Text += delta + } + } else { + // No content index - this is reasoning summary accumulation + if message.ResponsesReasoning == nil { + message.ResponsesReasoning = &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + } + } + + // For now, accumulate into a single summary entry + // In the future, this could be enhanced to handle multiple summary entries + if len(message.ResponsesReasoning.Summary) == 0 { + message.ResponsesReasoning.Summary = append(message.ResponsesReasoning.Summary, schemas.ResponsesReasoningSummary{ + Type: schemas.ResponsesReasoningContentBlockTypeSummaryText, + Text: delta, + }) + } else { + // Append to the first (and typically only) summary entry + message.ResponsesReasoning.Summary[0].Text += delta + } + } +} + +// appendReasoningSignatureToResponsesMessage appends reasoning signature to a responses message +func (a *Accumulator) appendReasoningSignatureToResponsesMessage(message *schemas.ResponsesMessage, signature string, contentIndex *int) { + // Handle signature content in content blocks or ResponsesReasoning.EncryptedContent + + // If we have a content index, this is signature content in content blocks + if contentIndex != nil { + if message.Content == nil { + message.Content = &schemas.ResponsesMessageContent{} + } + + // If we don't have content blocks yet, create them + if message.Content.ContentBlocks == nil { + message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, *contentIndex+1) + } + + // Ensure we have enough content blocks + for len(message.Content.ContentBlocks) <= *contentIndex { + message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{}) + } + + // Initialize the content block if needed + if message.Content.ContentBlocks[*contentIndex].Type == "" { + message.Content.ContentBlocks[*contentIndex].Type = schemas.ResponsesOutputMessageContentTypeReasoning + } + + // Set or append signature to the content block + if message.Content.ContentBlocks[*contentIndex].Signature == nil { + message.Content.ContentBlocks[*contentIndex].Signature = &signature + } else { + *message.Content.ContentBlocks[*contentIndex].Signature += signature + } + } else { + // No content index - this is encrypted content at the reasoning level + if message.ResponsesReasoning == nil { + message.ResponsesReasoning = &schemas.ResponsesReasoning{ + Summary: []schemas.ResponsesReasoningSummary{}, + } + } + + // Set or append to encrypted content + if message.ResponsesReasoning.EncryptedContent == nil { + message.ResponsesReasoning.EncryptedContent = &signature + } else { + *message.ResponsesReasoning.EncryptedContent += signature + } + } +} + // processAccumulatedResponsesStreamingChunks processes all accumulated responses streaming chunks in order func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID string, respErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) { accumulator := a.getOrCreateStreamAccumulator(requestID) diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index bed9ca07e..546d1285c 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -221,6 +221,38 @@ func (r *ResponsesRequestInput) UnmarshalJSON(data []byte) error { return fmt.Errorf("invalid responses request input") } +// UnmarshalJSON implements custom JSON unmarshalling for ResponsesRequest. +// This is needed because ResponsesParameters has a custom UnmarshalJSON method, +// which interferes with sonic's handling of the embedded BifrostParams struct. +func (rr *ResponsesRequest) UnmarshalJSON(data []byte) error { + // First, unmarshal BifrostParams fields directly + type bifrostAlias BifrostParams + var bp bifrostAlias + if err := sonic.Unmarshal(data, &bp); err != nil { + return err + } + rr.BifrostParams = BifrostParams(bp) + + // Unmarshal messages + var inputStruct struct { + Input ResponsesRequestInput `json:"input"` + } + if err := sonic.Unmarshal(data, &inputStruct); err != nil { + return err + } + rr.Input = inputStruct.Input + + // Unmarshal ResponsesParameters (which has its own custom unmarshaller) + if rr.ResponsesParameters == nil { + rr.ResponsesParameters = &schemas.ResponsesParameters{} + } + if err := sonic.Unmarshal(data, rr.ResponsesParameters); err != nil { + return err + } + + return nil +} + // ResponsesRequest is a bifrost responses request type ResponsesRequest struct { Input ResponsesRequestInput `json:"input"` diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 0c49a96d5..8cd0d87d3 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -92,7 +92,7 @@ func TransportInterceptorMiddleware(config *lib.Config) lib.BifrostHTTPMiddlewar } for _, plugin := range plugins { // Call TransportInterceptor on all plugins - pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) + pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) modifiedHeaders, modifiedBody, err := plugin.TransportInterceptor(pluginCtx, string(ctx.Request.URI().RequestURI()), headers, requestBody) cancel() if err != nil { diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index db44cf6ed..b7115027a 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -4,9 +4,11 @@ import ( "context" "errors" "fmt" + "log" "strconv" "strings" + "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/schemas" @@ -38,6 +40,11 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { return nil, errors.New("invalid request type") }, TextResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { + if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + } return anthropic.ToAnthropicTextCompletionResponse(resp), nil }, ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { @@ -63,13 +70,13 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { return &schemas.BifrostRequest{ - ResponsesRequest: anthropicReq.ToBifrostResponsesRequest(), + ResponsesRequest: anthropicReq.ToBifrostResponsesRequest(*ctx), }, nil } return nil, errors.New("invalid request type") }, ResponsesResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) { - if resp.ExtraFields.Provider == schemas.Anthropic { + if isClaudeModel(resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment, string(resp.ExtraFields.Provider)) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -81,24 +88,42 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { }, StreamConfig: &StreamConfig{ ResponsesStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { - anthropicResponse := anthropic.ToAnthropicResponsesStreamResponse(resp) - // Should never happen, but just in case - if anthropicResponse == nil { + if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { + if resp.ExtraFields.RawResponse != nil { + raw, ok := resp.ExtraFields.RawResponse.(string) + if !ok { + return "", nil, fmt.Errorf("expected RawResponse string, got %T", resp.ExtraFields.RawResponse) + } + var rawResponseJSON anthropic.AnthropicStreamEvent + if err := sonic.Unmarshal([]byte(raw), &rawResponseJSON); err == nil { + return string(rawResponseJSON.Type), raw, nil + } + } return "", nil, nil } - if resp.ExtraFields.Provider == schemas.Anthropic { - // This is always true in integrations - isRawResponseEnabled, ok := (*ctx).Value(schemas.BifrostContextKeySendBackRawResponse).(bool) - if ok && isRawResponseEnabled { - if resp.ExtraFields.RawResponse != nil { - return string(anthropicResponse.Type), resp.ExtraFields.RawResponse, nil - } else { - // Explicitly return nil to indicate that no raw response is available (because 1 chunk of anthropic gets converted to multiple bifrost responses chunks) - return "", nil, nil + anthropicResponse := anthropic.ToAnthropicResponsesStreamResponse(*ctx, resp) + // Can happen for openai lifecycle events + if len(anthropicResponse) == 0 { + return "", nil, nil + } else { + if len(anthropicResponse) > 1 { + combinedContent := "" + for _, event := range anthropicResponse { + responseJSON, err := sonic.Marshal(event) + if err != nil { + // Log JSON marshaling error but continue processing (should not happen) + log.Printf("Failed to marshal streaming response: %v", err) + continue + } + combinedContent += fmt.Sprintf("event: %s\ndata: %s\n\n", event.Type, responseJSON) } + return "", combinedContent, nil + } else if len(anthropicResponse) == 1 { + return string(anthropicResponse[0].Type), anthropicResponse[0], nil + } else { + return "", nil, nil } } - return string(anthropicResponse.Type), anthropicResponse, nil }, ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicResponsesStreamError(err) @@ -168,28 +193,57 @@ func checkAnthropicPassthrough(ctx *fasthttp.RequestCtx, bifrostCtx *context.Con } } - if !strings.Contains(model, "claude") || (provider != schemas.Anthropic && provider != "") { - // Not a Claude model or not an Anthropic model, so we can continue - return nil + headers := extractHeadersFromRequest(ctx) + if len(headers) > 0 { + // Check for User-Agent header (case-insensitive) + var userAgent []string + for key, value := range headers { + if strings.EqualFold(key, "user-agent") { + userAgent = value + break + } + } + if len(userAgent) > 0 { + // Check if it's claude code + if strings.Contains(userAgent[0], "claude-cli") { + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyUserAgent, "claude-cli") + } + } } // Check if anthropic oauth headers are present - if !isAnthropicAPIKeyAuth(ctx) { - headers := extractHeadersFromRequest(ctx) - url := extractExactPath(ctx) - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyExtraHeaders, headers) - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyURLPath, url) - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeySkipKeySelection, true) + if shouldUsePassthrough(bifrostCtx, provider, model, "") { *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyUseRawRequestBody, true) + if !isAnthropicAPIKeyAuth(ctx) && (provider == schemas.Anthropic || provider == "") { + url := extractExactPath(ctx) + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyExtraHeaders, headers) + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyURLPath, url) + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeySkipKeySelection, true) + } } - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKey("is_anthropic_passthrough"), true) return nil } +func shouldUsePassthrough(ctx *context.Context, provider schemas.ModelProvider, model string, deployment string) bool { + isClaudeCode := false + if userAgent, ok := (*ctx).Value(schemas.BifrostContextKeyUserAgent).(string); ok { + if strings.Contains(userAgent, "claude-cli") { + isClaudeCode = true + } + } + return isClaudeCode && isClaudeModel(model, deployment, string(provider)) +} + +func isClaudeModel(model, deployment, provider string) bool { + return (provider == string(schemas.Anthropic) || + (provider == "" && schemas.IsAnthropicModel(model))) || + (provider == string(schemas.Vertex) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(deployment))) || + (provider == string(schemas.Azure) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(deployment))) +} + // extractAnthropicListModelsParams extracts query parameters for list models request func extractAnthropicListModelsParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 4720224a0..974ff94c3 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -309,6 +309,9 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Set send back raw response flag for all integration requests *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeySendBackRawResponse, true) + // Set integration type to context + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyIntegrationType, string(config.Type)) + // Parse request body based on configuration if method != fasthttp.MethodGet { if config.RequestParser != nil { @@ -703,7 +706,10 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *co eventStreamEncoder = eventstream.NewEncoder() } - var shouldSendDoneMarker bool + shouldSendDoneMarker := true + if config.Type == RouteConfigTypeAnthropic || strings.Contains(config.Path, "/responses") { + shouldSendDoneMarker = false + } // Process streaming responses for chunk := range streamChan { @@ -800,10 +806,6 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *co convertedResponse, err = nil, fmt.Errorf("no response converter found for request type: %s", requestType) } - if eventType == "" { - shouldSendDoneMarker = true - } - if convertedResponse == nil && err == nil { // Skip streaming chunk if no response is available and no error is returned continue @@ -878,7 +880,7 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *co // CUSTOM SSE FORMAT: The converter returned a complete SSE string // This is used by providers like Anthropic that need custom event types // Example: "event: content_block_delta\ndata: {...}\n\n" - if !strings.HasPrefix(sseString, "data: ") { + if !strings.HasPrefix(sseString, "data: ") && !strings.HasPrefix(sseString, "event: ") { sseString = fmt.Sprintf("data: %s\n\n", sseString) } if _, err := fmt.Fprint(w, sseString); err != nil { diff --git a/ui/app/workspace/logs/views/columns.tsx b/ui/app/workspace/logs/views/columns.tsx index 579725597..c69d275cb 100644 --- a/ui/app/workspace/logs/views/columns.tsx +++ b/ui/app/workspace/logs/views/columns.tsx @@ -40,7 +40,7 @@ function getMessage(log?: LogEntry) { } else if (log?.speech_input) { return log.speech_input.input; } else if (log?.transcription_input) { - return log.transcription_input.prompt || "Audio file"; + return "Audio file"; } return ""; } diff --git a/ui/app/workspace/logs/views/logDetailsSheet.tsx b/ui/app/workspace/logs/views/logDetailsSheet.tsx index bc32c65c3..084c22a4b 100644 --- a/ui/app/workspace/logs/views/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/views/logDetailsSheet.tsx @@ -184,6 +184,57 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet )} + {(() => { + const params = log.params as any; + const reasoning = params?.reasoning; + if (!reasoning || typeof reasoning !== "object" || Object.keys(reasoning).length === 0) { + return null; + } + return ( + <> + +
+ } /> +
+ {reasoning.effort && ( + + {reasoning.effort} + + } + /> + )} + {reasoning.summary && ( + + {reasoning.summary} + + } + /> + )} + {reasoning.generate_summary && ( + + {reasoning.generate_summary} + + } + /> + )} + {reasoning.max_tokens && } +
+
+ + ); + })()} {log.cache_debug && ( <> diff --git a/ui/app/workspace/logs/views/logResponsesMessageView.tsx b/ui/app/workspace/logs/views/logResponsesMessageView.tsx index 4dc577aa7..a1852de1c 100644 --- a/ui/app/workspace/logs/views/logResponsesMessageView.tsx +++ b/ui/app/workspace/logs/views/logResponsesMessageView.tsx @@ -199,6 +199,10 @@ const renderMessage = (message: ResponsesMessage, index: number) => { return message.role ? `${message.role.charAt(0).toUpperCase() + message.role.slice(1)}` : "Message"; }; + if (message.type == "reasoning" && (!message.summary || message.summary.length === 0) && !message.encrypted_content && !message.content) { + return null; + } + return (
@@ -266,7 +270,7 @@ const renderMessage = (message: ResponsesMessage, index: number) => { options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.content}
+
{message.content}
)}
diff --git a/ui/app/workspace/logs/views/logResponsesOutputView.tsx b/ui/app/workspace/logs/views/logResponsesOutputView.tsx deleted file mode 100644 index 4dc577aa7..000000000 --- a/ui/app/workspace/logs/views/logResponsesOutputView.tsx +++ /dev/null @@ -1,347 +0,0 @@ -import { ResponsesMessage, ResponsesMessageContentBlock } from "@/lib/types/logs"; -import { CodeEditor } from "./codeEditor"; -import { isJson, cleanJson } from "@/lib/utils/validation"; - -interface LogResponsesMessageViewProps { - messages: ResponsesMessage[]; -} - -const renderContentBlock = (block: ResponsesMessageContentBlock, index: number) => { - const getBlockTitle = (type: string) => { - switch (type) { - case "input_text": - return "Input Text"; - case "input_image": - return "Input Image"; - case "input_file": - return "Input File"; - case "input_audio": - return "Input Audio"; - case "output_text": - return "Output Text"; - case "reasoning_text": - return "Reasoning Text"; - case "refusal": - return "Refusal"; - default: - return type.replace(/_/g, " ").replace(/\b\w/g, (l) => l.toUpperCase()); - } - }; - - return ( -
- {!block.text &&
{getBlockTitle(block.type)}
} - - {/* Handle text content */} - {block.text && ( -
- {isJson(block.text) ? ( - - ) : ( -
{block.text}
- )} -
- )} - - {/* Handle image content */} - {block.image_url && ( -
- -
- )} - - {/* Handle file content */} - {(block.file_id || block.file_data || block.file_url) && ( -
- -
- )} - - {/* Handle audio content */} - {block.input_audio && ( -
- -
- )} - - {/* Handle refusal content */} - {block.refusal && ( -
-
{block.refusal}
-
- )} - - {/* Handle annotations */} - {block.annotations && block.annotations.length > 0 && ( -
-
Annotations:
- -
- )} - - {/* Handle log probabilities */} - {block.logprobs && block.logprobs.length > 0 && ( -
-
Log Probabilities:
- -
- )} -
- ); -}; - -const renderMessage = (message: ResponsesMessage, index: number) => { - const getMessageTitle = () => { - if (message.type) { - switch (message.type) { - case "reasoning": - return "Reasoning"; - case "message": - return message.role ? `${message.role.charAt(0).toUpperCase() + message.role.slice(1)} Message` : "Message"; - case "function_call": - return `Function Call: ${message.name || "Unknown"}`; - case "function_call_output": - return "Function Call Output"; - case "file_search_call": - return "File Search"; - case "web_search_call": - return "Web Search"; - case "computer_call": - return "Computer Action"; - case "computer_call_output": - return "Computer Action Output"; - case "code_interpreter_call": - return "Code Interpreter"; - case "mcp_call": - return "MCP Tool Call"; - case "custom_tool_call": - return "Custom Tool Call"; - case "custom_tool_call_output": - return "Custom Tool Output"; - case "image_generation_call": - return "Image Generation"; - case "refusal": - return "Refusal"; - default: - return message.type.replace(/_/g, " ").replace(/\b\w/g, (l) => l.toUpperCase()); - } - } - return message.role ? `${message.role.charAt(0).toUpperCase() + message.role.slice(1)}` : "Message"; - }; - - return ( -
-
- {getMessageTitle()} - {/* {message.status && {message.status}} */} -
- - {/* Handle reasoning content */} - {message.type === "reasoning" && message.summary && message.summary.length > 0 && ( -
- {message.summary.every((item) => item.type === "summary_text") ? ( - // Display as readable text when all items are summary_text - message.summary.map((reasoningContent, idx) => ( -
-
Summary #{idx + 1}
-
-
{reasoningContent.text}
-
-
- )) - ) : ( - // Fallback to JSON display for mixed or non-text types -
- -
- )} -
- )} - - {/* Handle encrypted reasoning content */} - {message.type === "reasoning" && message.encrypted_content && ( -
-
Encrypted Reasoning Content
-
-
{message.encrypted_content}
-
-
- )} - - {/* Handle regular content */} - {message.content && ( -
- {typeof message.content === "string" ? ( - <> -
Content
-
- {isJson(message.content) ? ( - - ) : ( -
{message.content}
- )} -
- - ) : ( - Array.isArray(message.content) && message.content.map((block, blockIndex) => renderContentBlock(block, blockIndex)) - )} -
- )} - - {/* Handle tool call specific fields */} - {(message.call_id || message.name || message.arguments) && ( -
-
Tool Details
- -
- )} - - {/* Handle additional tool-specific fields */} - {Object.keys(message).some( - (key) => !["id", "type", "status", "role", "content", "call_id", "name", "arguments", "summary", "encrypted_content"].includes(key), - ) && ( -
-
Additional Fields
- - !["id", "type", "status", "role", "content", "call_id", "name", "arguments", "summary", "encrypted_content"].includes( - key, - ), - ), - ), - null, - 2, - )} - lang="json" - readonly={true} - options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} - /> -
- )} -
- ); -}; - -export default function LogResponsesMessageView({ messages }: LogResponsesMessageViewProps) { - if (!messages || messages.length === 0) { - return ( -
-
No responses messages available
-
- ); - } - - return
{messages.map((message, index) => renderMessage(message, index))}
; -} diff --git a/ui/lib/types/logs.ts b/ui/lib/types/logs.ts index b6b6ed682..b52c1414c 100644 --- a/ui/lib/types/logs.ts +++ b/ui/lib/types/logs.ts @@ -409,13 +409,13 @@ export interface ResponsesToolMessage { } // Reasoning content -export interface ResponsesReasoningContent { +export interface ResponsesReasoningSummary { type: "summary_text"; text: string; } export interface ResponsesReasoning { - summary: ResponsesReasoningContent[]; + summary: ResponsesReasoningSummary[]; encrypted_content?: string; } @@ -431,7 +431,7 @@ export interface ResponsesMessage { name?: string; arguments?: string; // Reasoning fields (merged when type is "reasoning") - summary?: ResponsesReasoningContent[]; + summary?: ResponsesReasoningSummary[]; encrypted_content?: string; // Additional tool-specific fields [key: string]: any; diff --git a/ui/package-lock.json b/ui/package-lock.json index f7e3943a4..fdd6eb97b 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -39,7 +39,7 @@ "cmdk": "1.1.1", "date-fns": "4.1.0", "lodash.isequal": "4.5.0", - "lucide-react": "0.542.0", + "lucide-react": "^0.542.0", "moment": "2.30.1", "monaco-editor": "0.52.2", "next": "15.5.8", diff --git a/ui/package.json b/ui/package.json index 694b2d0b1..0dd6de269 100644 --- a/ui/package.json +++ b/ui/package.json @@ -43,7 +43,7 @@ "cmdk": "1.1.1", "date-fns": "4.1.0", "lodash.isequal": "4.5.0", - "lucide-react": "0.542.0", + "lucide-react": "^0.542.0", "moment": "2.30.1", "monaco-editor": "0.52.2", "next": "15.5.8",