Skip to content

Commit dcf76b0

Browse files
committed
fix: reroute requests incoming from integrations for full compatibility in custom providers
1 parent 315d945 commit dcf76b0

File tree

8 files changed

+228
-17
lines changed

8 files changed

+228
-17
lines changed

core/providers/anthropic/anthropic.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,19 @@ func (provider *AnthropicProvider) TextCompletionStream(ctx context.Context, pos
309309
// Returns a BifrostResponse containing the completion results or an error if the request fails.
310310
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
311311
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
312+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
313+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
314+
if err != nil {
315+
return nil, err
316+
}
317+
318+
response := responsesResponse.ToBifrostChatResponse()
319+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
320+
response.ExtraFields.Provider = provider.GetProviderKey()
321+
response.ExtraFields.ModelRequested = request.Model
322+
323+
return response, nil
324+
}
312325
return nil, err
313326
}
314327

@@ -363,6 +376,10 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
363376
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
364377
func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
365378
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
379+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
380+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
381+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
382+
}
366383
return nil, err
367384
}
368385

@@ -671,6 +688,19 @@ func HandleAnthropicChatCompletionStreaming(
671688
// Returns a BifrostResponse containing the completion results or an error if the request fails.
672689
func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
673690
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
691+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
692+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
693+
if err != nil {
694+
return nil, err
695+
}
696+
697+
response := chatResponse.ToBifrostResponsesResponse()
698+
response.ExtraFields.RequestType = schemas.ResponsesRequest
699+
response.ExtraFields.Provider = provider.GetProviderKey()
700+
response.ExtraFields.ModelRequested = request.Model
701+
702+
return response, nil
703+
}
674704
return nil, err
675705
}
676706

@@ -724,6 +754,15 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
724754
// ResponsesStream performs a streaming responses request to the Anthropic API.
725755
func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
726756
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
757+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
758+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
759+
return provider.ChatCompletionStream(
760+
ctx,
761+
postHookRunner,
762+
key,
763+
request.ToChatRequest(),
764+
)
765+
}
727766
return nil, err
728767
}
729768

core/providers/bedrock/bedrock.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,19 @@ func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postH
677677
// Returns a BifrostResponse containing the completion results or an error if the request fails.
678678
func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
679679
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
680+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
681+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
682+
if err != nil {
683+
return nil, err
684+
}
685+
686+
response := responsesResponse.ToBifrostChatResponse()
687+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
688+
response.ExtraFields.Provider = provider.GetProviderKey()
689+
response.ExtraFields.ModelRequested = request.Model
690+
691+
return response, nil
692+
}
680693
return nil, err
681694
}
682695

@@ -748,6 +761,10 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas
748761
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
749762
func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
750763
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
764+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
765+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
766+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
767+
}
751768
return nil, err
752769
}
753770

@@ -916,6 +933,19 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
916933
// Returns a BifrostResponse containing the completion results or an error if the request fails.
917934
func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
918935
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
936+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
937+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
938+
if err != nil {
939+
return nil, err
940+
}
941+
942+
response := chatResponse.ToBifrostResponsesResponse()
943+
response.ExtraFields.RequestType = schemas.ResponsesRequest
944+
response.ExtraFields.Provider = provider.GetProviderKey()
945+
response.ExtraFields.ModelRequested = request.Model
946+
947+
return response, nil
948+
}
919949
return nil, err
920950
}
921951

@@ -989,6 +1019,15 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
9891019
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
9901020
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
9911021
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
1022+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
1023+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
1024+
return provider.ChatCompletionStream(
1025+
ctx,
1026+
postHookRunner,
1027+
key,
1028+
request.ToChatRequest(),
1029+
)
1030+
}
9921031
return nil, err
9931032
}
9941033

core/providers/cohere/cohere.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,19 @@ func (provider *CohereProvider) TextCompletionStream(ctx context.Context, postHo
280280
func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
281281
// Check if chat completion is allowed
282282
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
283+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
284+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
285+
if err != nil {
286+
return nil, err
287+
}
288+
289+
response := responsesResponse.ToBifrostChatResponse()
290+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
291+
response.ExtraFields.Provider = provider.GetProviderKey()
292+
response.ExtraFields.ModelRequested = request.Model
293+
294+
return response, nil
295+
}
283296
return nil, err
284297
}
285298

@@ -334,6 +347,10 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.
334347
func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
335348
// Check if chat completion stream is allowed
336349
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
350+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
351+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
352+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
353+
}
337354
return nil, err
338355
}
339356

@@ -506,6 +523,19 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
506523
func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
507524
// Check if chat completion is allowed
508525
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
526+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
527+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
528+
if err != nil {
529+
return nil, err
530+
}
531+
532+
response := chatResponse.ToBifrostResponsesResponse()
533+
response.ExtraFields.RequestType = schemas.ResponsesRequest
534+
response.ExtraFields.Provider = provider.GetProviderKey()
535+
response.ExtraFields.ModelRequested = request.Model
536+
537+
return response, nil
538+
}
509539
return nil, err
510540
}
511541

@@ -560,6 +590,15 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key,
560590
func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
561591
// Check if responses stream is allowed
562592
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
593+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
594+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
595+
return provider.ChatCompletionStream(
596+
ctx,
597+
postHookRunner,
598+
key,
599+
request.ToChatRequest(),
600+
)
601+
}
563602
return nil, err
564603
}
565604

core/providers/gemini/gemini.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,19 @@ func (provider *GeminiProvider) TextCompletionStream(ctx context.Context, postHo
217217
func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
218218
// Check if chat completion is allowed for this provider
219219
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
220+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
221+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
222+
if err != nil {
223+
return nil, err
224+
}
225+
226+
response := responsesResponse.ToBifrostChatResponse()
227+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
228+
response.ExtraFields.Provider = provider.GetProviderKey()
229+
response.ExtraFields.ModelRequested = request.Model
230+
231+
return response, nil
232+
}
220233
return nil, err
221234
}
222235

@@ -256,6 +269,10 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.
256269
func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
257270
// Check if chat completion stream is allowed for this provider
258271
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
272+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
273+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
274+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
275+
}
259276
return nil, err
260277
}
261278

@@ -496,6 +513,19 @@ func HandleGeminiChatCompletionStream(
496513
// Returns a BifrostResponse containing the completion results or an error if the request fails.
497514
func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
498515
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
516+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
517+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
518+
if err != nil {
519+
return nil, err
520+
}
521+
522+
response := chatResponse.ToBifrostResponsesResponse()
523+
response.ExtraFields.RequestType = schemas.ResponsesRequest
524+
response.ExtraFields.Provider = provider.GetProviderKey()
525+
response.ExtraFields.ModelRequested = request.Model
526+
527+
return response, nil
528+
}
499529
return nil, err
500530
}
501531

@@ -542,6 +572,10 @@ func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key,
542572
func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
543573
// Check if responses stream is allowed for this provider
544574
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
575+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
576+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
577+
return provider.ChatCompletionStream(ctx, postHookRunner, key, request.ToChatRequest())
578+
}
545579
return nil, err
546580
}
547581

core/providers/openai/openai.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,19 @@ func HandleOpenAITextCompletionStreaming(
612612
func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
613613
// Check if chat completion is allowed for this provider
614614
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
615+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
616+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
617+
if err != nil {
618+
return nil, err
619+
}
620+
621+
response := responsesResponse.ToBifrostChatResponse()
622+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
623+
response.ExtraFields.Provider = provider.GetProviderKey()
624+
response.ExtraFields.ModelRequested = request.Model
625+
626+
return response, nil
627+
}
615628
return nil, err
616629
}
617630

@@ -719,6 +732,10 @@ func HandleOpenAIChatCompletionRequest(
719732
func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
720733
// Check if chat completion stream is allowed for this provider
721734
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
735+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
736+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
737+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
738+
}
722739
return nil, err
723740
}
724741
var authHeader map[string]string
@@ -1097,6 +1114,19 @@ func HandleOpenAIChatCompletionStreaming(
10971114
func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
10981115
// Check if chat completion is allowed for this provider
10991116
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
1117+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
1118+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
1119+
if err != nil {
1120+
return nil, err
1121+
}
1122+
1123+
response := chatResponse.ToBifrostResponsesResponse()
1124+
response.ExtraFields.RequestType = schemas.ResponsesRequest
1125+
response.ExtraFields.Provider = provider.GetProviderKey()
1126+
response.ExtraFields.ModelRequested = request.Model
1127+
1128+
return response, nil
1129+
}
11001130
return nil, err
11011131
}
11021132

@@ -1203,6 +1233,10 @@ func HandleOpenAIResponsesRequest(
12031233
func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
12041234
// Check if chat completion stream is allowed for this provider
12051235
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
1236+
if ctx, shouldFallback := providerUtils.ShouldAttemptIntegrationFallback(ctx); shouldFallback {
1237+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
1238+
return provider.ChatCompletionStream(ctx, postHookRunner, key, request.ToChatRequest())
1239+
}
12061240
return nil, err
12071241
}
12081242
var authHeader map[string]string

core/providers/utils/utils.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,3 +1024,23 @@ func GetRandomString(length int) string {
10241024
}
10251025
return string(b)
10261026
}
1027+
1028+
// ShouldAttemptIntegrationFallback checks if an integration fallback should be attempted.
1029+
// It returns:
1030+
// - modified context with fallback flag set (if fallback should proceed)
1031+
// - boolean indicating whether to proceed with fallback
1032+
func ShouldAttemptIntegrationFallback(ctx context.Context) (context.Context, bool) {
1033+
// Check if this is an integration request
1034+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); !ok {
1035+
return ctx, false
1036+
}
1037+
1038+
// Check if fallback has already been attempted
1039+
if attempted, _ := ctx.Value(schemas.BifrostContextKeyIntegrationFallbackAttempted).(bool); attempted {
1040+
return ctx, false
1041+
}
1042+
1043+
// Mark fallback as attempted and return modified context
1044+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIntegrationFallbackAttempted, true)
1045+
return ctx, true
1046+
}

0 commit comments

Comments
 (0)