Skip to content

Commit 8dca6ec

Browse files
committed
fix: reroute requests incoming from integrations for full compatibility in custom providers
1 parent bccaa4b commit 8dca6ec

File tree

6 files changed

+176
-15
lines changed

6 files changed

+176
-15
lines changed

core/providers/anthropic/anthropic.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,19 @@ func (provider *AnthropicProvider) TextCompletionStream(ctx context.Context, pos
308308
// Returns a BifrostResponse containing the completion results or an error if the request fails.
309309
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
310310
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
311+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
312+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
313+
if err != nil {
314+
return nil, err
315+
}
316+
317+
response := responsesResponse.ToBifrostChatResponse()
318+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
319+
response.ExtraFields.Provider = provider.GetProviderKey()
320+
response.ExtraFields.ModelRequested = request.Model
321+
322+
return response, nil
323+
}
311324
return nil, err
312325
}
313326

@@ -358,6 +371,10 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
358371
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
359372
func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
360373
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
374+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
375+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
376+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
377+
}
361378
return nil, err
362379
}
363380

@@ -659,6 +676,19 @@ func HandleAnthropicChatCompletionStreaming(
659676
// Returns a BifrostResponse containing the completion results or an error if the request fails.
660677
func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
661678
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
679+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
680+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
681+
if err != nil {
682+
return nil, err
683+
}
684+
685+
response := chatResponse.ToBifrostResponsesResponse()
686+
response.ExtraFields.RequestType = schemas.ResponsesRequest
687+
response.ExtraFields.Provider = provider.GetProviderKey()
688+
response.ExtraFields.ModelRequested = request.Model
689+
690+
return response, nil
691+
}
662692
return nil, err
663693
}
664694

@@ -707,6 +737,15 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
707737
// ResponsesStream performs a streaming responses request to the Anthropic API.
708738
func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
709739
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
740+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
741+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
742+
return provider.ChatCompletionStream(
743+
ctx,
744+
postHookRunner,
745+
key,
746+
request.ToChatRequest(),
747+
)
748+
}
710749
return nil, err
711750
}
712751

core/providers/bedrock/bedrock.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,19 @@ func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postH
663663
// Returns a BifrostResponse containing the completion results or an error if the request fails.
664664
func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
665665
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
666+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
667+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
668+
if err != nil {
669+
return nil, err
670+
}
671+
672+
response := responsesResponse.ToBifrostChatResponse()
673+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
674+
response.ExtraFields.Provider = provider.GetProviderKey()
675+
response.ExtraFields.ModelRequested = request.Model
676+
677+
return response, nil
678+
}
666679
return nil, err
667680
}
668681

@@ -729,6 +742,10 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas
729742
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
730743
func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
731744
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
745+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
746+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
747+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
748+
}
732749
return nil, err
733750
}
734751

@@ -891,6 +908,19 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
891908
// Returns a BifrostResponse containing the completion results or an error if the request fails.
892909
func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
893910
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
911+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
912+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
913+
if err != nil {
914+
return nil, err
915+
}
916+
917+
response := chatResponse.ToBifrostResponsesResponse()
918+
response.ExtraFields.RequestType = schemas.ResponsesRequest
919+
response.ExtraFields.Provider = provider.GetProviderKey()
920+
response.ExtraFields.ModelRequested = request.Model
921+
922+
return response, nil
923+
}
894924
return nil, err
895925
}
896926

@@ -959,6 +989,15 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key,
959989
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
960990
func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
961991
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
992+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
993+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
994+
return provider.ChatCompletionStream(
995+
ctx,
996+
postHookRunner,
997+
key,
998+
request.ToChatRequest(),
999+
)
1000+
}
9621001
return nil, err
9631002
}
9641003

core/providers/cohere/cohere.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,19 @@ func (provider *CohereProvider) TextCompletionStream(ctx context.Context, postHo
288288
func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
289289
// Check if chat completion is allowed
290290
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
291+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
292+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
293+
if err != nil {
294+
return nil, err
295+
}
296+
297+
response := responsesResponse.ToBifrostChatResponse()
298+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
299+
response.ExtraFields.Provider = provider.GetProviderKey()
300+
response.ExtraFields.ModelRequested = request.Model
301+
302+
return response, nil
303+
}
291304
return nil, err
292305
}
293306

@@ -337,6 +350,10 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.
337350
func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
338351
// Check if chat completion stream is allowed
339352
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
353+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
354+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
355+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
356+
}
340357
return nil, err
341358
}
342359

@@ -504,6 +521,19 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
504521
func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
505522
// Check if chat completion is allowed
506523
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
524+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
525+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
526+
if err != nil {
527+
return nil, err
528+
}
529+
530+
response := chatResponse.ToBifrostResponsesResponse()
531+
response.ExtraFields.RequestType = schemas.ResponsesRequest
532+
response.ExtraFields.Provider = provider.GetProviderKey()
533+
response.ExtraFields.ModelRequested = request.Model
534+
535+
return response, nil
536+
}
507537
return nil, err
508538
}
509539

@@ -553,6 +583,15 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key,
553583
func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
554584
// Check if responses stream is allowed
555585
if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
586+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
587+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
588+
return provider.ChatCompletionStream(
589+
ctx,
590+
postHookRunner,
591+
key,
592+
request.ToChatRequest(),
593+
)
594+
}
556595
return nil, err
557596
}
558597

core/providers/openai/openai.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,19 @@ func HandleOpenAITextCompletionStreaming(
580580
func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
581581
// Check if chat completion is allowed for this provider
582582
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
583+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
584+
responsesResponse, err := provider.Responses(ctx, key, request.ToResponsesRequest())
585+
if err != nil {
586+
return nil, err
587+
}
588+
589+
response := responsesResponse.ToBifrostChatResponse()
590+
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
591+
response.ExtraFields.Provider = provider.GetProviderKey()
592+
response.ExtraFields.ModelRequested = request.Model
593+
594+
return response, nil
595+
}
583596
return nil, err
584597
}
585598

@@ -680,6 +693,10 @@ func HandleOpenAIChatCompletionRequest(
680693
func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
681694
// Check if chat completion stream is allowed for this provider
682695
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
696+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
697+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsChatCompletionToResponsesStreamFallback, true)
698+
return provider.ResponsesStream(ctx, postHookRunner, key, request.ToResponsesRequest())
699+
}
683700
return nil, err
684701
}
685702
var authHeader map[string]string
@@ -1042,6 +1059,19 @@ func HandleOpenAIChatCompletionStreaming(
10421059
func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
10431060
// Check if chat completion is allowed for this provider
10441061
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
1062+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
1063+
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
1064+
if err != nil {
1065+
return nil, err
1066+
}
1067+
1068+
response := chatResponse.ToBifrostResponsesResponse()
1069+
response.ExtraFields.RequestType = schemas.ResponsesRequest
1070+
response.ExtraFields.Provider = provider.GetProviderKey()
1071+
response.ExtraFields.ModelRequested = request.Model
1072+
1073+
return response, nil
1074+
}
10451075
return nil, err
10461076
}
10471077

@@ -1141,6 +1171,15 @@ func HandleOpenAIResponsesRequest(
11411171
func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
11421172
// Check if chat completion stream is allowed for this provider
11431173
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
1174+
if _, ok := ctx.Value(schemas.BifrostContextKeyIntegrationRequest).(bool); ok {
1175+
ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
1176+
return provider.ChatCompletionStream(
1177+
ctx,
1178+
postHookRunner,
1179+
key,
1180+
request.ToChatRequest(),
1181+
)
1182+
}
11441183
return nil, err
11451184
}
11461185
var authHeader map[string]string

core/schemas/bifrost.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,23 @@ type BifrostContextKey string
102102

103103
// BifrostContextKeyRequestType is a context key for the request type.
104104
const (
105-
BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string
106-
BifrostContextKeyRequestID BifrostContextKey = "request-id" // string
107-
BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string
108-
BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct
109-
BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost))
110-
BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost))
111-
BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost))
112-
BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc.
113-
BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost)
114-
BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider)
115-
BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string
116-
BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string
117-
BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" // bool
118-
BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool
119-
BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost)
105+
BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string
106+
BifrostContextKeyRequestID BifrostContextKey = "request-id" // string
107+
BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string
108+
BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct
109+
BifrostContextKeyIntegrationRequest BifrostContextKey = "bifrost-integration-request" // bool (set by bifrost)
110+
BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost))
111+
BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost))
112+
BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost))
113+
BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc.
114+
BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost)
115+
BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider)
116+
BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string
117+
BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string
118+
BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" // bool
119+
BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool
120+
BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost)
121+
BifrostContextKeyIsChatCompletionToResponsesStreamFallback BifrostContextKey = "bifrost-is-chat-completion-to-responses-stream-fallback" // bool (set by bifrost)
120122
)
121123

122124
// NOTE: for custom plugin implementation dealing with streaming short circuit,

0 commit comments

Comments
 (0)