diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index aa1dfa1fd..7ce0399c8 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -671,6 +671,19 @@ func HandleAnthropicChatCompletionStreaming( // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err + } + + response := chatResponse.ToBifrostResponsesResponse() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + + return response, nil + } return nil, err } @@ -719,6 +732,15 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke // ResponsesStream performs a streaming responses request to the Anthropic API. func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + return provider.ChatCompletionStream( + ctx, + postHookRunner, + key, + request.ToChatRequest(), + ) + } return nil, err } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 9dbcb8a12..1baef5b5b 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -916,6 +916,19 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err + } + + response := chatResponse.ToBifrostResponsesResponse() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + + return response, nil + } return nil, err } @@ -989,6 +1002,15 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, // Returns a channel for streaming BifrostResponse objects or an error if the request fails. func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + return provider.ChatCompletionStream( + ctx, + postHookRunner, + key, + request.ToChatRequest(), + ) + } return nil, err } diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index f5171971a..7c73af588 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -504,6 +504,19 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Check if chat completion is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err + } + + response := chatResponse.ToBifrostResponsesResponse() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + + return response, nil + } return nil, err } @@ -558,6 +571,15 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if responses stream is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + return provider.ChatCompletionStream( + ctx, + postHookRunner, + key, + request.ToChatRequest(), + ) + } return nil, err } diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index f18ab271a..d51c915b6 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -499,6 +499,19 @@ func HandleGeminiChatCompletionStream( // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err + } + + response := chatResponse.ToBifrostResponsesResponse() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + + return response, nil + } return nil, err } @@ -550,6 +563,10 @@ func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if responses stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + return provider.ChatCompletionStream(ctx, postHookRunner, key, request.ToChatRequest()) + } return nil, err } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 5e40db546..2efce61c1 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -1097,6 +1097,19 @@ func HandleOpenAIChatCompletionStreaming( func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err + } + + response := chatResponse.ToBifrostResponsesResponse() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + + return response, nil + } return nil, err } @@ -1203,6 +1216,10 @@ func HandleOpenAIResponsesRequest( func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + return provider.ChatCompletionStream(ctx, postHookRunner, key, request.ToChatRequest()) + } return nil, err } var authHeader map[string]string diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index ba86f5cb6..c947a048e 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -1119,3 +1119,23 @@ func GetBudgetTokensFromReasoningEffort( return budget, nil } + +// ShouldAttemptIntegrationFallback checks if an integration fallback should be attempted. +// It returns: +// - modified context with fallback flag set (if fallback should proceed) +// - boolean indicating whether to proceed with fallback +func ShouldAttemptIntegrationFallback(ctx context.Context) (context.Context, bool) { + // Check if this is an integration request + if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); !ok { + return ctx, false + } + + // Check if fallback has already been attempted + if attempted, _ := ctx.Value(schemas.BifrostContextKeyIntegrationFallbackAttempted).(bool); attempted { + return ctx, false + } + + // Mark fallback as attempted and return modified context + ctx = context.WithValue(ctx, schemas.BifrostContextKeyIntegrationFallbackAttempted, true) + return ctx, true +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 3c65a61e9..61dd18970 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -102,18 +102,20 @@ type BifrostContextKey string // BifrostContextKeyRequestType is a context key for the request type. const ( - BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string - BifrostContextKeyRequestID BifrostContextKey = "request-id" // string - BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string - BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc. - BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost) - BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) - BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string - BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string + BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string + BifrostContextKeyRequestID BifrostContextKey = "request-id" // string + BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string + BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct + BifrostContextKeyIntegrationRequest BifrostContextKey = "bifrost-integration-request" // bool (set by bifrost) + BifrostContextKeyIntegrationFallbackAttempted BifrostContextKey = "bifrost-integration-fallback-attempted" // bool (set by bifrost) + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc. + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost) + BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) + BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string + BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 974ff94c3..8961d8edb 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -375,6 +375,9 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle } } + // set context value to indicate this request is being handled by a bifrost integration + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyIntegrationRequest, true) + if isStreaming { g.handleStreamingRequest(ctx, config, bifrostReq, bifrostCtx, cancel) } else {