Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -2722,20 +2722,36 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
// Use the custom provider name for actual key selection, but pass base provider type for key validation
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
if err != nil {
bifrost.logger.Debug("error selecting key for model %s: %v", model, err)
req.Err <- schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: err.Error(),
Error: err,
},
ExtraFields: schemas.BifrostErrorExtraFields{
Provider: provider.GetProviderKey(),
ModelRequested: model,
RequestType: req.RequestType,
},
// Here if model is not required - for example file operations, or batch list operation - we don't need to throw an error
// We can pick first available key of the provider and continue
if !isModelRequired(req.RequestType) {
// Get first available key of the provider
// TODO this is temporary solution, we will fix this
// This is only for Bedrock provider, we will be adding special flag in the next release
keys, err := bifrost.account.GetKeysForProvider(&req.Context, provider.GetProviderKey())
if err != nil {
bifrost.logger.Debug("error getting keys for provider %s: %v", provider.GetProviderKey(), err)
continue
}
if len(keys) > 0 {
key = keys[0]
}
} else {
bifrost.logger.Debug("error selecting key for model %s: %v", model, err)
req.Err <- schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: err.Error(),
Error: err,
},
ExtraFields: schemas.BifrostErrorExtraFields{
Provider: provider.GetProviderKey(),
ModelRequested: model,
RequestType: req.RequestType,
},
}
continue
}
continue
}
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyID, key.ID)
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyName, key.Name)
Expand Down
2 changes: 2 additions & 0 deletions core/internal/testutil/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ type ComprehensiveTestConfig struct {
SpeechSynthesisFallbacks []schemas.Fallback // for speech synthesis tests
EmbeddingFallbacks []schemas.Fallback // for embedding tests
SkipReason string // Reason to skip certain tests
BatchExtraParams map[string]interface{} // Extra params for batch operations (e.g., role_arn, output_s3_uri for Bedrock)
FileExtraParams map[string]interface{} // Extra params for file operations (e.g., s3_bucket for Bedrock)
}

// ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing.
Expand Down
54 changes: 32 additions & 22 deletions core/internal/testutil/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func RunBatchCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte
},
},
CompletionWindow: "24h",
ExtraParams: testConfig.BatchExtraParams,
}

response, err := client.BatchCreateRequest(ctx, request)
Expand Down Expand Up @@ -127,6 +128,7 @@ func RunBatchRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con
},
},
CompletionWindow: "24h",
ExtraParams: testConfig.BatchExtraParams,
}

createResponse, createErr := client.BatchCreateRequest(ctx, createRequest)
Expand Down Expand Up @@ -198,6 +200,7 @@ func RunBatchCancelTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte
},
},
CompletionWindow: "24h",
ExtraParams: testConfig.BatchExtraParams,
}

createResponse, createErr := client.BatchCreateRequest(ctx, createRequest)
Expand Down Expand Up @@ -336,10 +339,11 @@ func RunFileUploadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex
`)

request := &schemas.BifrostFileUploadRequest{
Provider: testConfig.Provider,
File: fileContent,
Filename: "test_batch.jsonl",
Purpose: "batch",
Provider: testConfig.Provider,
File: fileContent,
Filename: "test_batch.jsonl",
Purpose: "batch",
ExtraParams: testConfig.FileExtraParams,
}

response, err := client.FileUploadRequest(ctx, request)
Expand Down Expand Up @@ -378,8 +382,9 @@ func RunFileListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context,
t.Logf("[RUNNING] File List test for provider: %s", testConfig.Provider)

request := &schemas.BifrostFileListRequest{
Provider: testConfig.Provider,
Limit: 10,
Provider: testConfig.Provider,
Limit: 10,
ExtraParams: testConfig.FileExtraParams,
}

response, err := client.FileListRequest(ctx, request)
Expand Down Expand Up @@ -417,10 +422,11 @@ func RunFileRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Cont
`)

uploadRequest := &schemas.BifrostFileUploadRequest{
Provider: testConfig.Provider,
File: fileContent,
Filename: "test_retrieve.jsonl",
Purpose: "batch",
Provider: testConfig.Provider,
File: fileContent,
Filename: "test_retrieve.jsonl",
Purpose: "batch",
ExtraParams: testConfig.FileExtraParams,
}

uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest)
Expand Down Expand Up @@ -479,10 +485,11 @@ func RunFileDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex
`)

uploadRequest := &schemas.BifrostFileUploadRequest{
Provider: testConfig.Provider,
File: fileContent,
Filename: "test_delete.jsonl",
Purpose: "batch",
Provider: testConfig.Provider,
File: fileContent,
Filename: "test_delete.jsonl",
Purpose: "batch",
ExtraParams: testConfig.FileExtraParams,
}

uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest)
Expand Down Expand Up @@ -541,10 +548,11 @@ func RunFileContentTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte
`)

uploadRequest := &schemas.BifrostFileUploadRequest{
Provider: testConfig.Provider,
File: originalContent,
Filename: "test_content.jsonl",
Purpose: "batch",
Provider: testConfig.Provider,
File: originalContent,
Filename: "test_content.jsonl",
Purpose: "batch",
ExtraParams: testConfig.FileExtraParams,
}

uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest)
Expand Down Expand Up @@ -643,10 +651,11 @@ func RunFileAndBatchIntegrationTest(t *testing.T, client *bifrost.Bifrost, ctx c
`)

uploadRequest := &schemas.BifrostFileUploadRequest{
Provider: testConfig.Provider,
File: fileContent,
Filename: "integration_test_batch.jsonl",
Purpose: "batch",
Provider: testConfig.Provider,
File: fileContent,
Filename: "integration_test_batch.jsonl",
Purpose: "batch",
ExtraParams: testConfig.FileExtraParams,
}

uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest)
Expand All @@ -673,6 +682,7 @@ func RunFileAndBatchIntegrationTest(t *testing.T, client *bifrost.Bifrost, ctx c
InputFileID: uploadResponse.ID,
Endpoint: schemas.BatchEndpointChatCompletions,
CompletionWindow: "24h",
ExtraParams: testConfig.BatchExtraParams,
}

batchResponse, batchErr := client.BatchCreateRequest(ctx, batchRequest)
Expand Down
8 changes: 6 additions & 2 deletions core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1457,8 +1457,12 @@ func (provider *BedrockProvider) FileList(ctx context.Context, key schemas.Key,
}

region := DefaultBedrockRegion
if key.BedrockKeyConfig.Region != nil {
region = *key.BedrockKeyConfig.Region
if key.BedrockKeyConfig != nil {
if key.BedrockKeyConfig.Region != nil {
region = *key.BedrockKeyConfig.Region
}
} else {
region = DefaultBedrockRegion
}

// Build S3 ListObjectsV2 request
Expand Down
26 changes: 24 additions & 2 deletions core/providers/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ func TestBedrock(t *testing.T) {
}
defer cancel()

// Get Bedrock-specific configuration from environment
s3Bucket := os.Getenv("AWS_S3_BUCKET")
roleArn := os.Getenv("AWS_BEDROCK_ROLE_ARN")

// Build extra params for batch and file operations
var batchExtraParams map[string]interface{}
var fileExtraParams map[string]interface{}

if s3Bucket != "" {
fileExtraParams = map[string]interface{}{
"s3_bucket": s3Bucket,
}
batchExtraParams = map[string]interface{}{
"output_s3_uri": "s3://" + s3Bucket + "/batch-output/",
}
if roleArn != "" {
batchExtraParams["role_arn"] = roleArn
}
}

testConfig := testutil.ComprehensiveTestConfig{
Provider: schemas.Bedrock,
ChatModel: "claude-4-sonnet",
Expand All @@ -50,8 +70,10 @@ func TestBedrock(t *testing.T) {
{Provider: schemas.Bedrock, Model: "claude-4-sonnet"},
{Provider: schemas.Bedrock, Model: "claude-4.5-sonnet"},
},
EmbeddingModel: "cohere.embed-v4:0",
ReasoningModel: "claude-4.5-sonnet",
EmbeddingModel: "cohere.embed-v4:0",
ReasoningModel: "claude-4.5-sonnet",
BatchExtraParams: batchExtraParams,
FileExtraParams: fileExtraParams,
Scenarios: testutil.TestScenarios{
TextCompletion: false, // Not supported
SimpleChat: true,
Expand Down
6 changes: 3 additions & 3 deletions core/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ var rateLimitPatterns = []string{
"concurrent requests limit",
}

// IsModelRequired returns true if the request type requires a model
func IsModelRequired(reqType schemas.RequestType) bool {
// isModelRequired returns true if the request type requires a model
func isModelRequired(reqType schemas.RequestType) bool {
return reqType == schemas.TextCompletionRequest || reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.EmbeddingRequest
}

Expand Down Expand Up @@ -97,7 +97,7 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError {
if provider == "" {
return newBifrostErrorFromMsg("provider is required")
}
if IsModelRequired(req.RequestType) && model == "" {
if isModelRequired(req.RequestType) && model == "" {
return newBifrostErrorFromMsg("model is required")
}

Expand Down
3 changes: 2 additions & 1 deletion tests/integrations/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@
"bedrock_key_config": {
"access_key": "env.AWS_ACCESS_KEY_ID",
"secret_key": "env.AWS_SECRET_ACCESS_KEY",
"region": "env.AWS_REGION"
"region": "env.AWS_REGION",
"arn": "env.AWS_BEDROCK_ROLE_ARN"
},
"weight": 1
}
Expand Down