diff --git a/core/internal/testutil/automatic_function_calling.go b/core/internal/testutil/automatic_function_calling.go index 93b9711a5..d4c64d1a2 100644 --- a/core/internal/testutil/automatic_function_calling.go +++ b/core/internal/testutil/automatic_function_calling.go @@ -162,7 +162,7 @@ func validateAutomaticToolCall(t *testing.T, toolCalls []ToolCallInfo, apiName s // Validation for tool call already happened inside WithDualAPITestRetry // If we reach here, the tool call was successful // This function just provides additional logging for tool call details - + for _, toolCall := range toolCalls { if toolCall.Name == string(SampleToolTypeTime) { t.Logf("✅ %s automatic function call: %s", apiName, toolCall.Arguments) diff --git a/core/internal/testutil/complete_end_to_end.go b/core/internal/testutil/complete_end_to_end.go index 852135f49..e0fe00a7d 100644 --- a/core/internal/testutil/complete_end_to_end.go +++ b/core/internal/testutil/complete_end_to_end.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/embedding.go b/core/internal/testutil/embedding.go index b044858b9..ca21aa5f9 100644 --- a/core/internal/testutil/embedding.go +++ b/core/internal/testutil/embedding.go @@ -8,7 +8,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/end_to_end_tool_calling.go b/core/internal/testutil/end_to_end_tool_calling.go index cd9294253..6f1a841a5 100644 --- a/core/internal/testutil/end_to_end_tool_calling.go +++ b/core/internal/testutil/end_to_end_tool_calling.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/image_base64.go b/core/internal/testutil/image_base64.go index b11265319..99db7ff7c 100644 --- a/core/internal/testutil/image_base64.go +++ b/core/internal/testutil/image_base64.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/image_url.go b/core/internal/testutil/image_url.go index db9d93195..88ba12e15 100644 --- a/core/internal/testutil/image_url.go +++ b/core/internal/testutil/image_url.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/list_models.go b/core/internal/testutil/list_models.go index 43be92f2a..92e2868b0 100644 --- a/core/internal/testutil/list_models.go +++ b/core/internal/testutil/list_models.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/multi_turn_conversation.go b/core/internal/testutil/multi_turn_conversation.go index e7d471c92..cbdbff86e 100644 --- a/core/internal/testutil/multi_turn_conversation.go +++ b/core/internal/testutil/multi_turn_conversation.go @@ -132,18 +132,18 @@ func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx con expectations2.ShouldContainKeywords = []string{"alice"} // Case insensitive expectations2.ShouldNotContainWords = []string{"don't know", "can't remember", "forgot"} // Memory failure indicators - response2, bifrostErr := WithChatTestRetry(t, chatRetryConfig2, retryContext2, expectations2, "MultiTurnConversation_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return client.ChatCompletionRequest(ctx, secondRequest) - }) + response2, bifrostErr := WithChatTestRetry(t, chatRetryConfig2, retryContext2, expectations2, "MultiTurnConversation_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return client.ChatCompletionRequest(ctx, secondRequest) + }) - if bifrostErr != nil { - t.Fatalf("❌ MultiTurnConversation_Step2 request failed after retries: %v", GetErrorMessage(bifrostErr)) - } + if bifrostErr != nil { + t.Fatalf("❌ MultiTurnConversation_Step2 request failed after retries: %v", GetErrorMessage(bifrostErr)) + } - // Validation already happened inside WithChatTestRetry via expectations2 - // If we reach here, the model successfully remembered "Alice" - content := GetChatContent(response2) - t.Logf("✅ Model successfully remembered the name: %s", content) - t.Logf("✅ Multi-turn conversation completed successfully") + // Validation already happened inside WithChatTestRetry via expectations2 + // If we reach here, the model successfully remembered "Alice" + content := GetChatContent(response2) + t.Logf("✅ Model successfully remembered the name: %s", content) + t.Logf("✅ Multi-turn conversation completed successfully") }) } diff --git a/core/internal/testutil/multiple_images.go b/core/internal/testutil/multiple_images.go index 9c1ea07b3..b5f65109f 100644 --- a/core/internal/testutil/multiple_images.go +++ b/core/internal/testutil/multiple_images.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/multiple_tool_calls.go b/core/internal/testutil/multiple_tool_calls.go index d1cf7bfbd..08b917795 100644 --- a/core/internal/testutil/multiple_tool_calls.go +++ b/core/internal/testutil/multiple_tool_calls.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/simple_chat.go b/core/internal/testutil/simple_chat.go index df9f95d81..99a3dc273 100644 --- a/core/internal/testutil/simple_chat.go +++ b/core/internal/testutil/simple_chat.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/text_completion.go b/core/internal/testutil/text_completion.go index 6f5eab740..d14b554f3 100644 --- a/core/internal/testutil/text_completion.go +++ b/core/internal/testutil/text_completion.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/text_completion_stream.go b/core/internal/testutil/text_completion_stream.go index c731bd338..3554fb4eb 100644 --- a/core/internal/testutil/text_completion_stream.go +++ b/core/internal/testutil/text_completion_stream.go @@ -8,7 +8,6 @@ import ( "testing" "time" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) diff --git a/core/internal/testutil/tool_calls.go b/core/internal/testutil/tool_calls.go index 739684b1f..3105b268f 100644 --- a/core/internal/testutil/tool_calls.go +++ b/core/internal/testutil/tool_calls.go @@ -7,7 +7,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/require" diff --git a/core/providers/elevenlabs/speech.go b/core/providers/elevenlabs/speech.go index 52028ce7f..b2b91e35c 100644 --- a/core/providers/elevenlabs/speech.go +++ b/core/providers/elevenlabs/speech.go @@ -86,4 +86,4 @@ func ToElevenlabsSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) *Eleven } return elevenlabsReq -} \ No newline at end of file +} diff --git a/core/providers/openai/types_test.go b/core/providers/openai/types_test.go index 7e5dc5f07..59e9f53cf 100644 --- a/core/providers/openai/types_test.go +++ b/core/providers/openai/types_test.go @@ -460,4 +460,3 @@ func TestOpenAIChatRequest_UnmarshalJSON_ValueAssertions(t *testing.T) { t.Errorf("Expected Stop value ['END', 'STOP'], got %v", req.Stop) } } - diff --git a/core/providers/utils/pagination.go b/core/providers/utils/pagination.go index 8e7905671..a136600fd 100644 --- a/core/providers/utils/pagination.go +++ b/core/providers/utils/pagination.go @@ -109,4 +109,3 @@ func (h *SerialListHelper) HasMoreKeys() bool { } return currentKeyIndex < len(h.Keys)-1 } - diff --git a/core/schemas/pagination.go b/core/schemas/pagination.go index 1d6d9b580..fe39ab0af 100644 --- a/core/schemas/pagination.go +++ b/core/schemas/pagination.go @@ -60,4 +60,3 @@ func NewSerialCursor(keyIndex int, cursor string) *SerialCursor { Cursor: cursor, } } - diff --git a/framework/configstore/config.go b/framework/configstore/config.go index dd061d744..3300f71d3 100644 --- a/framework/configstore/config.go +++ b/framework/configstore/config.go @@ -13,8 +13,8 @@ type ConfigStoreType string // ConfigStoreTypeSQLite is the type of config store for SQLite. const ( - ConfigStoreTypeSQLite ConfigStoreType = "sqlite" - ConfigStoreTypePostgres ConfigStoreType = "postgres" + ConfigStoreTypeSQLite ConfigStoreType = "sqlite" + ConfigStoreTypePostgres ConfigStoreType = "postgres" ) // Config represents the configuration for the config store. diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 472400c7f..a9dccf329 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -665,7 +665,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo if processedARN, err := envutils.ProcessEnvValue(*bedrockConfig.ARN); err == nil { bedrockConfigCopy.ARN = &processedARN } - } + } bedrockConfig = &bedrockConfigCopy } diff --git a/framework/configstore/tables/budget.go b/framework/configstore/tables/budget.go index 1744363b2..eb64ac792 100644 --- a/framework/configstore/tables/budget.go +++ b/framework/configstore/tables/budget.go @@ -27,11 +27,11 @@ type TableBudget struct { func (TableBudget) TableName() string { return "governance_budgets" } // BeforeSave hook for Budget to validate reset duration format and max limit -func (b *TableBudget) BeforeSave(tx *gorm.DB) error { +func (b *TableBudget) BeforeSave(tx *gorm.DB) error { // Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y") if d, err := ParseDuration(b.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration) - }else if d <= 0 { + } else if d <= 0 { return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration) } // Validate that MaxLimit is not negative (budgets should be positive) diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index 536d6a565..054626749 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -15,8 +15,8 @@ type TableClientConfig struct { AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"` EnableLogging bool `gorm:"" json:"enable_logging"` - DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged - LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) + DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged + LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) EnableGovernance bool `gorm:"" json:"enable_governance"` EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"` AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"` @@ -33,7 +33,7 @@ type TableClientConfig struct { // Virtual fields for runtime use (not stored in DB) PrometheusLabels []string `gorm:"-" json:"prometheus_labels"` - AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` + AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` } // TableName sets the table name for each model diff --git a/framework/configstore/tables/config.go b/framework/configstore/tables/config.go index cd3cdf82a..3d9570114 100644 --- a/framework/configstore/tables/config.go +++ b/framework/configstore/tables/config.go @@ -10,18 +10,16 @@ const ( ConfigProxyKey = "proxy_config" ) - - // GlobalProxyConfig represents the global proxy configuration type GlobalProxyConfig struct { - Enabled bool `json:"enabled"` - Type network.GlobalProxyType `json:"type"` // "http", "socks5", "tcp" - URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080) - Username string `json:"username,omitempty"` // Optional authentication username - Password string `json:"password,omitempty"` // Optional authentication password - NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy - Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds - SkipTLSVerify bool `json:"skip_tls_verify,omitempty"`// Skip TLS certificate verification + Enabled bool `json:"enabled"` + Type network.GlobalProxyType `json:"type"` // "http", "socks5", "tcp" + URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080) + Username string `json:"username,omitempty"` // Optional authentication username + Password string `json:"password,omitempty"` // Optional authentication password + NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy + Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds + SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` // Skip TLS certificate verification // Entity enablement flags EnableForSCIM bool `json:"enable_for_scim"` // Enable proxy for SCIM requests (enterprise only) EnableForInference bool `json:"enable_for_inference"` // Enable proxy for inference requests diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index ede86a175..2e536949e 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -39,13 +39,13 @@ type TableKey struct { VertexDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Bedrock config fields (embedded) - BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"` - BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"` - BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"` - BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"` - BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"` - BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config + BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"` + BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"` + BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"` + BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"` + BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"` + BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config // Batch API configuration UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index f5c2381a6..28eaee8c0 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -10,14 +10,14 @@ import ( // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. - ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` - Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` - ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType - ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` - StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig - ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash diff --git a/framework/encrypt/encrypt_test.go b/framework/encrypt/encrypt_test.go index 60cc1f0c7..68d25a5f9 100644 --- a/framework/encrypt/encrypt_test.go +++ b/framework/encrypt/encrypt_test.go @@ -216,7 +216,7 @@ func TestKDFDeterministic(t *testing.T) { // Re-initialize with same passphrase (simulating restart) Init(passphrase, bifrost.NewDefaultLogger(schemas.LogLevelInfo)) - + // Should be able to decrypt the previously encrypted data decrypted, err := Decrypt(encrypted1) if err != nil { diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go index 7281790e9..245dcdf5b 100644 --- a/plugins/jsonparser/main.go +++ b/plugins/jsonparser/main.go @@ -89,6 +89,7 @@ func (p *JsonParserPlugin) GetName() string { // - url: The URL of the request // - headers: The request headers // - body: The request body +// // Returns: // - map[string]string: The updated request headers // - map[string]any: The updated request body @@ -101,6 +102,7 @@ func (p *JsonParserPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url // Parameters: // - ctx: The Bifrost context // - req: The Bifrost request +// // Returns: // - *schemas.BifrostRequest: The processed request // - *schemas.PluginShortCircuit: The plugin short circuit if the request is not allowed @@ -114,6 +116,7 @@ func (p *JsonParserPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif // - ctx: The Bifrost context // - result: The Bifrost response to be processed // - err: The Bifrost error to be processed +// // Returns: // - *schemas.BifrostResponse: The processed response // - *schemas.BifrostError: The processed error diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 614c45063..1397e3e33 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -272,7 +272,7 @@ func (p *LoggerPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bifrost initialData.SpeechInput = req.SpeechRequest.Input case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: initialData.Params = req.TranscriptionRequest.Params - initialData.TranscriptionInput = req.TranscriptionRequest.Input + initialData.TranscriptionInput = req.TranscriptionRequest.Input } } diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 0da4592ff..fb7598daf 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -377,7 +377,7 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR ctx.SetValue(requestIDKey, requestID) ctx.SetValue(requestModelKey, model) ctx.SetValue(requestProviderKey, provider) - + performDirectSearch, performSemanticSearch := true, true if (*ctx).Value(CacheTypeKey) != nil { cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) diff --git a/plugins/semanticcache/plugin_integration_test.go b/plugins/semanticcache/plugin_integration_test.go index 602873613..eccf415db 100644 --- a/plugins/semanticcache/plugin_integration_test.go +++ b/plugins/semanticcache/plugin_integration_test.go @@ -18,7 +18,7 @@ func TestSemanticCacheBasicFlow(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") - + // Test request request := &schemas.BifrostRequest{ RequestType: schemas.ChatCompletionRequest, @@ -309,7 +309,7 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { setup := NewTestSetup(t) defer setup.Cleanup() - ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") request := &schemas.BifrostRequest{ @@ -547,7 +547,7 @@ func TestSemanticCache_CustomThresholdHandling(t *testing.T) { defer setup.Cleanup() // Configure plugin with custom threshold key - ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") ctx.SetValue(CacheThresholdKey, 0.95) // Very high threshold diff --git a/plugins/semanticcache/search.go b/plugins/semanticcache/search.go index 1872e3624..b3dc0cb32 100644 --- a/plugins/semanticcache/search.go +++ b/plugins/semanticcache/search.go @@ -297,7 +297,7 @@ func (plugin *Plugin) buildStreamingResponseFromResult(ctx *schemas.BifrostConte // Mark cache-hit once to avoid concurrent ctx writes ctx.SetValue(isCacheHitKey, true) ctx.SetValue(cacheHitTypeKey, cacheType) - + // Create stream channel streamChan := make(chan *schemas.BifrostStream) diff --git a/plugins/semanticcache/stream.go b/plugins/semanticcache/stream.go index bd9f19e05..b11df493d 100644 --- a/plugins/semanticcache/stream.go +++ b/plugins/semanticcache/stream.go @@ -71,11 +71,11 @@ func (plugin *Plugin) processAccumulatedStream(ctx context.Context, requestID st accumulator := accumulatorInterface.(*StreamAccumulator) accumulator.mu.Lock() - + // Ensure unlock happens after cleanup defer accumulator.mu.Unlock() // Ensure cleanup happens - defer plugin.cleanupStreamAccumulator(requestID) + defer plugin.cleanupStreamAccumulator(requestID) // STEP 1: Check if any chunk in the entire stream had an error if accumulator.HasError { diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index cc04a411b..3e4000afb 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -285,29 +285,29 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } } - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) - } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + // Get keys for this provider config if specified + var keys []configstoreTables.TableKey + if len(pc.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } } - } - providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ - VirtualKeyID: vk.ID, - Provider: pc.Provider, - Weight: pc.Weight, - AllowedModels: pc.AllowedModels, - Keys: keys, - } + providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: pc.Weight, + AllowedModels: pc.AllowedModels, + Keys: keys, + } - // Create budget for provider config if provided - if pc.Budget != nil { + // Create budget for provider config if provided + if pc.Budget != nil { budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: pc.Budget.MaxLimit, @@ -588,29 +588,29 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return fmt.Errorf("both max_limit and reset_duration are required when creating a new provider budget") } } - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) - } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) - } - } + // Get keys for this provider config if specified + var keys []configstoreTables.TableKey + if len(pc.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } + } - // Create new provider config - providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ - VirtualKeyID: vk.ID, - Provider: pc.Provider, - Weight: pc.Weight, - AllowedModels: pc.AllowedModels, - Keys: keys, - } - // Create budget for provider config if provided - if pc.Budget != nil { + // Create new provider config + providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: pc.Weight, + AllowedModels: pc.AllowedModels, + Keys: keys, + } + // Create budget for provider config if provided + if pc.Budget != nil { budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: *pc.Budget.MaxLimit, @@ -648,32 +648,32 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if err := h.configStore.CreateVirtualKeyProviderConfig(ctx, providerConfig, tx); err != nil { return err } - } else { - // Update existing provider config - existing, ok := existingConfigsMap[*pc.ID] - if !ok { - return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) - } - requestConfigsMap[*pc.ID] = true - existing.Provider = pc.Provider - existing.Weight = pc.Weight - existing.AllowedModels = pc.AllowedModels - - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } else { + // Update existing provider config + existing, ok := existingConfigsMap[*pc.ID] + if !ok { + return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + requestConfigsMap[*pc.ID] = true + existing.Provider = pc.Provider + existing.Weight = pc.Weight + existing.AllowedModels = pc.AllowedModels + + // Get keys for this provider config if specified + var keys []configstoreTables.TableKey + if len(pc.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } } - } - existing.Keys = keys + existing.Keys = keys - // Handle budget updates for provider config + // Handle budget updates for provider config if pc.Budget != nil { if existing.BudgetID != nil { // Update existing budget diff --git a/transports/bifrost-http/handlers/middlewares_test.go b/transports/bifrost-http/handlers/middlewares_test.go index 395beb2ba..12ddb5894 100644 --- a/transports/bifrost-http/handlers/middlewares_test.go +++ b/transports/bifrost-http/handlers/middlewares_test.go @@ -12,12 +12,12 @@ import ( // mockLogger is a mock implementation of schemas.Logger for testing type mockLogger struct{} -func (m *mockLogger) Debug(format string, args ...any) {} -func (m *mockLogger) Info(format string, args ...any) {} -func (m *mockLogger) Warn(format string, args ...any) {} -func (m *mockLogger) Error(format string, args ...any) {} -func (m *mockLogger) Fatal(format string, args ...any) {} -func (m *mockLogger) SetLevel(level schemas.LogLevel) {} +func (m *mockLogger) Debug(format string, args ...any) {} +func (m *mockLogger) Info(format string, args ...any) {} +func (m *mockLogger) Warn(format string, args ...any) {} +func (m *mockLogger) Error(format string, args ...any) {} +func (m *mockLogger) Fatal(format string, args ...any) {} +func (m *mockLogger) SetLevel(level schemas.LogLevel) {} func (m *mockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} // TestCorsMiddleware_LocalhostOrigins tests that localhost origins are always allowed diff --git a/transports/bifrost-http/integrations/bedrock_test.go b/transports/bifrost-http/integrations/bedrock_test.go index 2556676b0..e03131d5f 100644 --- a/transports/bifrost-http/integrations/bedrock_test.go +++ b/transports/bifrost-http/integrations/bedrock_test.go @@ -530,12 +530,12 @@ func Test_extractBedrockJobArnFromPath(t *testing.T) { handlerStore := &mockHandlerStore{allowDirectKeys: false} tests := []struct { - name string - jobArn interface{} - provider schemas.ModelProvider - wantErr bool - wantJobArn string - errContains string + name string + jobArn interface{} + provider schemas.ModelProvider + wantErr bool + wantJobArn string + errContains string }{ { name: "valid job ARN for Bedrock", diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index ebdfaf069..566ab50a8 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -500,7 +500,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt if createReq, ok := req.(*schemas.BifrostBatchCreateRequest); ok { if createReq.Provider == "" { createReq.Provider = schemas.OpenAI - } + } // For Bedrock, extract extra params from raw body // ExtraParams has json:"-" tag so it's not auto-populated if createReq.Provider == schemas.Bedrock { @@ -1202,7 +1202,7 @@ func parseOpenAIFileUploadMultipartRequest(ctx *fasthttp.RequestCtx, req interfa purposeValues := form.Value["purpose"] if len(purposeValues) == 0 || purposeValues[0] == "" { return errors.New("purpose field is required") - } + } uploadReq.Purpose = schemas.FilePurpose(purposeValues[0]) // Extract file (required)