diff --git a/core/bifrost.go b/core/bifrost.go index 2ec2d3061..776a15d87 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -2722,20 +2722,36 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas // Use the custom provider name for actual key selection, but pass base provider type for key validation key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) if err != nil { - bifrost.logger.Debug("error selecting key for model %s: %v", model, err) - req.Err <- schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Message: err.Error(), - Error: err, - }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, - }, + // Here if model is not required - for example file operations, or batch list operation - we don't need to throw an error + // We can pick first available key of the provider and continue + if !isModelRequired(req.RequestType) { + // Get first available key of the provider + // TODO this is temporary solution, we will fix this + // This is only for Bedrock provider, we will be adding special flag in the next release + keys, err := bifrost.account.GetKeysForProvider(&req.Context, provider.GetProviderKey()) + if err != nil { + bifrost.logger.Debug("error getting keys for provider %s: %v", provider.GetProviderKey(), err) + continue + } + if len(keys) > 0 { + key = keys[0] + } + } else { + bifrost.logger.Debug("error selecting key for model %s: %v", model, err) + req.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider.GetProviderKey(), + ModelRequested: model, + RequestType: req.RequestType, + }, + } + continue } - continue } req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyID, key.ID) req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyName, key.Name) diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 3c53d84d4..58a0ef7d5 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -72,6 +72,8 @@ type ComprehensiveTestConfig struct { SpeechSynthesisFallbacks []schemas.Fallback // for speech synthesis tests EmbeddingFallbacks []schemas.Fallback // for embedding tests SkipReason string // Reason to skip certain tests + BatchExtraParams map[string]interface{} // Extra params for batch operations (e.g., role_arn, output_s3_uri for Bedrock) + FileExtraParams map[string]interface{} // Extra params for file operations (e.g., s3_bucket for Bedrock) } // ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. diff --git a/core/internal/testutil/batch.go b/core/internal/testutil/batch.go index fe28cb253..e468bd698 100644 --- a/core/internal/testutil/batch.go +++ b/core/internal/testutil/batch.go @@ -36,6 +36,7 @@ func RunBatchCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte }, }, CompletionWindow: "24h", + ExtraParams: testConfig.BatchExtraParams, } response, err := client.BatchCreateRequest(ctx, request) @@ -127,6 +128,7 @@ func RunBatchRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con }, }, CompletionWindow: "24h", + ExtraParams: testConfig.BatchExtraParams, } createResponse, createErr := client.BatchCreateRequest(ctx, createRequest) @@ -198,6 +200,7 @@ func RunBatchCancelTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte }, }, CompletionWindow: "24h", + ExtraParams: testConfig.BatchExtraParams, } createResponse, createErr := client.BatchCreateRequest(ctx, createRequest) @@ -336,10 +339,11 @@ func RunFileUploadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex `) request := &schemas.BifrostFileUploadRequest{ - Provider: testConfig.Provider, - File: fileContent, - Filename: "test_batch.jsonl", - Purpose: "batch", + Provider: testConfig.Provider, + File: fileContent, + Filename: "test_batch.jsonl", + Purpose: "batch", + ExtraParams: testConfig.FileExtraParams, } response, err := client.FileUploadRequest(ctx, request) @@ -378,8 +382,9 @@ func RunFileListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, t.Logf("[RUNNING] File List test for provider: %s", testConfig.Provider) request := &schemas.BifrostFileListRequest{ - Provider: testConfig.Provider, - Limit: 10, + Provider: testConfig.Provider, + Limit: 10, + ExtraParams: testConfig.FileExtraParams, } response, err := client.FileListRequest(ctx, request) @@ -417,10 +422,11 @@ func RunFileRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Cont `) uploadRequest := &schemas.BifrostFileUploadRequest{ - Provider: testConfig.Provider, - File: fileContent, - Filename: "test_retrieve.jsonl", - Purpose: "batch", + Provider: testConfig.Provider, + File: fileContent, + Filename: "test_retrieve.jsonl", + Purpose: "batch", + ExtraParams: testConfig.FileExtraParams, } uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) @@ -479,10 +485,11 @@ func RunFileDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex `) uploadRequest := &schemas.BifrostFileUploadRequest{ - Provider: testConfig.Provider, - File: fileContent, - Filename: "test_delete.jsonl", - Purpose: "batch", + Provider: testConfig.Provider, + File: fileContent, + Filename: "test_delete.jsonl", + Purpose: "batch", + ExtraParams: testConfig.FileExtraParams, } uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) @@ -541,10 +548,11 @@ func RunFileContentTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte `) uploadRequest := &schemas.BifrostFileUploadRequest{ - Provider: testConfig.Provider, - File: originalContent, - Filename: "test_content.jsonl", - Purpose: "batch", + Provider: testConfig.Provider, + File: originalContent, + Filename: "test_content.jsonl", + Purpose: "batch", + ExtraParams: testConfig.FileExtraParams, } uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) @@ -643,10 +651,11 @@ func RunFileAndBatchIntegrationTest(t *testing.T, client *bifrost.Bifrost, ctx c `) uploadRequest := &schemas.BifrostFileUploadRequest{ - Provider: testConfig.Provider, - File: fileContent, - Filename: "integration_test_batch.jsonl", - Purpose: "batch", + Provider: testConfig.Provider, + File: fileContent, + Filename: "integration_test_batch.jsonl", + Purpose: "batch", + ExtraParams: testConfig.FileExtraParams, } uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) @@ -673,6 +682,7 @@ func RunFileAndBatchIntegrationTest(t *testing.T, client *bifrost.Bifrost, ctx c InputFileID: uploadResponse.ID, Endpoint: schemas.BatchEndpointChatCompletions, CompletionWindow: "24h", + ExtraParams: testConfig.BatchExtraParams, } batchResponse, batchErr := client.BatchCreateRequest(ctx, batchRequest) diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 4f056553f..fdbf586f0 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -1457,8 +1457,12 @@ func (provider *BedrockProvider) FileList(ctx context.Context, key schemas.Key, } region := DefaultBedrockRegion - if key.BedrockKeyConfig.Region != nil { - region = *key.BedrockKeyConfig.Region + if key.BedrockKeyConfig != nil { + if key.BedrockKeyConfig.Region != nil { + region = *key.BedrockKeyConfig.Region + } + } else { + region = DefaultBedrockRegion } // Build S3 ListObjectsV2 request diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index 31c3fc227..d8cc3d798 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -42,6 +42,26 @@ func TestBedrock(t *testing.T) { } defer cancel() + // Get Bedrock-specific configuration from environment + s3Bucket := os.Getenv("AWS_S3_BUCKET") + roleArn := os.Getenv("AWS_BEDROCK_ROLE_ARN") + + // Build extra params for batch and file operations + var batchExtraParams map[string]interface{} + var fileExtraParams map[string]interface{} + + if s3Bucket != "" { + fileExtraParams = map[string]interface{}{ + "s3_bucket": s3Bucket, + } + batchExtraParams = map[string]interface{}{ + "output_s3_uri": "s3://" + s3Bucket + "/batch-output/", + } + if roleArn != "" { + batchExtraParams["role_arn"] = roleArn + } + } + testConfig := testutil.ComprehensiveTestConfig{ Provider: schemas.Bedrock, ChatModel: "claude-4-sonnet", @@ -50,8 +70,10 @@ func TestBedrock(t *testing.T) { {Provider: schemas.Bedrock, Model: "claude-4-sonnet"}, {Provider: schemas.Bedrock, Model: "claude-4.5-sonnet"}, }, - EmbeddingModel: "cohere.embed-v4:0", - ReasoningModel: "claude-4.5-sonnet", + EmbeddingModel: "cohere.embed-v4:0", + ReasoningModel: "claude-4.5-sonnet", + BatchExtraParams: batchExtraParams, + FileExtraParams: fileExtraParams, Scenarios: testutil.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, diff --git a/core/utils.go b/core/utils.go index 27f511b1a..352d776cd 100644 --- a/core/utils.go +++ b/core/utils.go @@ -47,8 +47,8 @@ var rateLimitPatterns = []string{ "concurrent requests limit", } -// IsModelRequired returns true if the request type requires a model -func IsModelRequired(reqType schemas.RequestType) bool { +// isModelRequired returns true if the request type requires a model +func isModelRequired(reqType schemas.RequestType) bool { return reqType == schemas.TextCompletionRequest || reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.EmbeddingRequest } @@ -97,7 +97,7 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError { if provider == "" { return newBifrostErrorFromMsg("provider is required") } - if IsModelRequired(req.RequestType) && model == "" { + if isModelRequired(req.RequestType) && model == "" { return newBifrostErrorFromMsg("model is required") } diff --git a/tests/integrations/config.json b/tests/integrations/config.json index a12220579..cdf6dbd59 100644 --- a/tests/integrations/config.json +++ b/tests/integrations/config.json @@ -135,7 +135,8 @@ "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", - "region": "env.AWS_REGION" + "region": "env.AWS_REGION", + "arn": "env.AWS_BEDROCK_ROLE_ARN" }, "weight": 1 }