Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions core/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,19 @@ func HandleAnthropicChatCompletionStreaming(
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
if err != nil {
return nil, err
}

response := chatResponse.ToBifrostResponsesResponse()
response.ExtraFields.RequestType = schemas.ResponsesRequest
response.ExtraFields.Provider = provider.GetProviderKey()
response.ExtraFields.ModelRequested = request.Model

return response, nil
}
return nil, err
}

Expand Down Expand Up @@ -719,6 +732,15 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
// ResponsesStream performs a streaming responses request to the Anthropic API.
func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
return provider.ChatCompletionStream(
ctx,
postHookRunner,
key,
request.ToChatRequest(),
)
}
return nil, err
}

Expand Down
22 changes: 22 additions & 0 deletions core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,19 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
if err != nil {
return nil, err
}

response := chatResponse.ToBifrostResponsesResponse()
response.ExtraFields.RequestType = schemas.ResponsesRequest
response.ExtraFields.Provider = provider.GetProviderKey()
response.ExtraFields.ModelRequested = request.Model

return response, nil
}
return nil, err
}

Expand Down Expand Up @@ -989,6 +1002,15 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
return provider.ChatCompletionStream(
ctx,
postHookRunner,
key,
request.ToChatRequest(),
)
}
return nil, err
}

Expand Down
22 changes: 22 additions & 0 deletions core/providers/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,19 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
// Check if chat completion is allowed
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
if err != nil {
return nil, err
}

response := chatResponse.ToBifrostResponsesResponse()
response.ExtraFields.RequestType = schemas.ResponsesRequest
response.ExtraFields.Provider = provider.GetProviderKey()
response.ExtraFields.ModelRequested = request.Model

return response, nil
}
return nil, err
}

Expand Down Expand Up @@ -558,6 +571,15 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key,
func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
// Check if responses stream is allowed
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
return provider.ChatCompletionStream(
ctx,
postHookRunner,
key,
request.ToChatRequest(),
)
}
return nil, err
}

Expand Down
17 changes: 17 additions & 0 deletions core/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,19 @@ func HandleGeminiChatCompletionStream(
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
if err != nil {
return nil, err
}

response := chatResponse.ToBifrostResponsesResponse()
response.ExtraFields.RequestType = schemas.ResponsesRequest
response.ExtraFields.Provider = provider.GetProviderKey()
response.ExtraFields.ModelRequested = request.Model

return response, nil
}
return nil, err
}

Expand Down Expand Up @@ -550,6 +563,10 @@ func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key,
func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
// Check if responses stream is allowed for this provider
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
return provider.ChatCompletionStream(ctx, postHookRunner, key, request.ToChatRequest())
}
return nil, err
}

Expand Down
17 changes: 17 additions & 0 deletions core/providers/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,19 @@ func HandleOpenAIChatCompletionStreaming(
func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
// Check if chat completion is allowed for this provider
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
if err != nil {
return nil, err
}

response := chatResponse.ToBifrostResponsesResponse()
response.ExtraFields.RequestType = schemas.ResponsesRequest
response.ExtraFields.Provider = provider.GetProviderKey()
response.ExtraFields.ModelRequested = request.Model

return response, nil
}
return nil, err
}

Expand Down Expand Up @@ -1203,6 +1216,10 @@ func HandleOpenAIResponsesRequest(
func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
// Check if chat completion stream is allowed for this provider
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
return provider.ChatCompletionStream(ctx, postHookRunner, key, request.ToChatRequest())
}
return nil, err
}
var authHeader map[string]string
Expand Down
20 changes: 20 additions & 0 deletions core/providers/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,23 @@ func GetBudgetTokensFromReasoningEffort(

return budget, nil
}

// ShouldAttemptIntegrationFallback checks if an integration fallback should be attempted.
// It returns:
// - modified context with fallback flag set (if fallback should proceed)
// - boolean indicating whether to proceed with fallback
func ShouldAttemptIntegrationFallback(ctx context.Context) (context.Context, bool) {
// Check if this is an integration request
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); !ok {
return ctx, false
}

// Check if fallback has already been attempted
if attempted, _ := ctx.Value(schemas.BifrostContextKeyIntegrationFallbackAttempted).(bool); attempted {
return ctx, false
}

// Mark fallback as attempted and return modified context
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIntegrationFallbackAttempted, true)
return ctx, true
}
26 changes: 14 additions & 12 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,20 @@ type BifrostContextKey string

// BifrostContextKeyRequestType is a context key for the request type.
const (
BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string
BifrostContextKeyRequestID BifrostContextKey = "request-id" // string
BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string
BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct
BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost))
BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost))
BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost))
BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc.
BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost)
BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider)
BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string
BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string
BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string
BifrostContextKeyRequestID BifrostContextKey = "request-id" // string
BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string
BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct
BifrostContextKeyIntegrationRequest BifrostContextKey = "bifrost-integration-request" // bool (set by bifrost)
BifrostContextKeyIntegrationFallbackAttempted BifrostContextKey = "bifrost-integration-fallback-attempted" // bool (set by bifrost)
BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost))
BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost))
BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost))
BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc.
BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost)
BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider)
BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string
BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string
BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body"
BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool
BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool
Expand Down
3 changes: 3 additions & 0 deletions transports/bifrost-http/integrations/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle
}
}

// set context value to indicate this request is being handled by a bifrost integration
*bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyIntegrationRequest, true)

if isStreaming {
g.handleStreamingRequest(ctx, config, bifrostReq, bifrostCtx, cancel)
} else {
Expand Down