diff --git a/.editorconfig b/.editorconfig index 7223b342a..71e2db663 100644 --- a/.editorconfig +++ b/.editorconfig @@ -5,5 +5,9 @@ insert_final_newline = false end_of_line = lf charset = utf-8 +[*.go] +indent_style = tab +indent_size = 4 + [*.{js,jsx,ts,tsx,mjs,json,md,css,scss,html}] insert_final_newline = false diff --git a/core/bifrost.go b/core/bifrost.go index c53a09e04..4d8a021b8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/azure" "github.com/maximhq/bifrost/core/providers/bedrock" @@ -66,7 +67,8 @@ type Bifrost struct { pluginPipelinePool sync.Pool // Pool for PluginPipeline objects bifrostRequestPool sync.Pool // Pool for BifrostRequest objects logger schemas.Logger // logger instance, default logger is used if not provided - mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) + mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured) + mcpInitOnce sync.Once // Ensures MCP manager is initialized only once dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. keySelector schemas.KeySelector // Custom key selector function } @@ -179,13 +181,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { // Initialize MCP manager if configured if config.MCPConfig != nil { - mcpManager, err := newMCPManager(bifrostCtx, *config.MCPConfig, bifrost.logger) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("failed to initialize MCP manager: %v", err)) - } else { - bifrost.mcpManager = mcpManager + bifrost.mcpInitOnce.Do(func() { + bifrost.mcpManager = mcp.NewMCPManager(bifrostCtx, *config.MCPConfig, bifrost.logger) bifrost.logger.Info("MCP integration initialized successfully") - } + }) } // Create buffered channels for each provider and start workers @@ -539,8 +538,7 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx context.Context, req *sc return bifrost.handleStreamRequest(ctx, bifrostReq) } -// ChatCompletionRequest sends a chat completion request to the specified provider. -func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) makeChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -574,10 +572,35 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. if err != nil { return nil, err } - //TODO: Release the response + return response.ChatResponse, nil } +// ChatCompletionRequest sends a chat completion request to the specified provider. +func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) + if ctx == nil { + ctx = bifrost.ctx + } + + response, err := bifrost.makeChatCompletionRequest(ctx, req) + if err != nil { + return nil, err + } + + // Check if we should enter agent mode + if bifrost.mcpManager != nil { + return bifrost.mcpManager.CheckAndExecuteAgentForChatRequest( + &ctx, + req, + response, + bifrost.makeChatCompletionRequest, + ) + } + + return response, nil +} + // ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { @@ -612,8 +635,7 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *sc return bifrost.handleStreamRequest(ctx, bifrostReq) } -// ResponsesRequest sends a responses request to the specified provider. -func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) makeResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -647,10 +669,34 @@ func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.Bifro if err != nil { return nil, err } - //TODO: Release the response return response.ResponsesResponse, nil } +// ResponsesRequest sends a responses request to the specified provider. +func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) + if ctx == nil { + ctx = bifrost.ctx + } + + response, err := bifrost.makeResponsesRequest(ctx, req) + if err != nil { + return nil, err + } + + // Check if we should enter agent mode + if bifrost.mcpManager != nil { + return bifrost.mcpManager.CheckAndExecuteAgentForResponsesRequest( + &ctx, + req, + response, + bifrost.makeResponsesRequest, + ) + } + + return response, nil +} + // ResponsesStreamRequest sends a responses stream request to the specified provider. func (bifrost *Bifrost) ResponsesStreamRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { @@ -1683,10 +1729,10 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.registerTool(name, description, handler, toolSchema) + return bifrost.mcpManager.RegisterTool(name, description, handler, toolSchema) } -// ExecuteMCPTool executes an MCP tool call and returns the result as a tool message. +// ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. // This is the main public API for manual MCP tool execution. // // Parameters: @@ -1696,7 +1742,7 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // Returns: // - schemas.ChatMessage: Tool message with execution result // - schemas.BifrostError: Any execution error -func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { +func (bifrost *Bifrost) ExecuteChatMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { if bifrost.mcpManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1709,13 +1755,12 @@ func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.Cha } } - result, err := bifrost.mcpManager.executeTool(ctx, toolCall) + result, err := bifrost.mcpManager.ExecuteChatTool(ctx, toolCall) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), - Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionRequest, // MCP tools are used with chat completions @@ -1726,6 +1771,44 @@ func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.Cha return result, nil } +// ExecuteResponsesMCPTool executes an MCP tool call and returns the result as a responses message. + +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - schemas.ResponsesMessage: Tool message with execution result +// - schemas.BifrostError: Any execution error +func (bifrost *Bifrost) ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) { + if bifrost.mcpManager == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "MCP is not configured in this Bifrost instance", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.ResponsesRequest, // MCP tools are used with responses requests + }, + } + } + + result, err := bifrost.mcpManager.ExecuteResponsesTool(ctx, toolCall) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.ResponsesRequest, // MCP tools are used with responses requests + }, + } + } + + return result, nil +} + // IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools) // may temporarily increase latency for incoming requests while the operations are being processed. // These operations involve network I/O and connection management that require mutex locks @@ -1741,12 +1824,9 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") } - clients, err := bifrost.mcpManager.GetClients() - if err != nil { - return nil, err - } - + clients := bifrost.mcpManager.GetClients() clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) + for _, client := range clients { tools := make([]schemas.ChatToolFunction, 0, len(client.ToolMap)) for _, tool := range client.ToolMap { @@ -1759,21 +1839,27 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { return tools[i].Name < tools[j].Name }) - state := schemas.MCPConnectionStateConnected - if client.Conn == nil { - state = schemas.MCPConnectionStateDisconnected - } - clientsInConfig = append(clientsInConfig, schemas.MCPClient{ Config: client.ExecutionConfig, Tools: tools, - State: state, + State: client.State, }) } return clientsInConfig, nil } +// GetAvailableTools returns the available tools for the given context. +// +// Returns: +// - []schemas.ChatTool: List of available tools +func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { + if bifrost.mcpManager == nil { + return nil + } + return bifrost.mcpManager.GetAvailableTools(ctx) +} + // AddMCPClient adds a new MCP client to the Bifrost instance. // This allows for dynamic MCP client management at runtime. // @@ -1792,13 +1878,17 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // }) func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { if bifrost.mcpManager == nil { - manager := &MCPManager{ - ctx: bifrost.ctx, - clientMap: make(map[string]*MCPClient), - logger: bifrost.logger, - } + // Use sync.Once to ensure thread-safe initialization + bifrost.mcpInitOnce.Do(func() { + bifrost.mcpManager = mcp.NewMCPManager(bifrost.ctx, schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{config}, + }, bifrost.logger) + }) + } - bifrost.mcpManager = manager + // Handle case where initialization succeeded elsewhere but manager is still nil + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP manager is not initialized") } return bifrost.mcpManager.AddClient(config) @@ -1866,6 +1956,22 @@ func (bifrost *Bifrost) ReconnectMCPClient(id string) error { return bifrost.mcpManager.ReconnectClient(id) } +// UpdateToolManagerConfig updates the tool manager config for the MCP manager. +// This allows for hot-reloading of the tool manager config at runtime. +func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") + } + + bifrost.mcpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ + MaxAgentDepth: maxAgentDepth, + ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + }) + return nil +} + + // PROVIDER MANAGEMENT // createBaseProvider creates a provider based on the base provider type @@ -2378,11 +2484,8 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ } // Add MCP tools to request if MCP is configured and requested - if req.RequestType != schemas.EmbeddingRequest && - req.RequestType != schemas.SpeechRequest && - req.RequestType != schemas.TranscriptionRequest && - bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + if bifrost.mcpManager != nil { + req = bifrost.mcpManager.AddToolsToRequest(ctx, req) } pipeline := bifrost.getPluginPipeline() @@ -2498,7 +2601,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro // Add MCP tools to request if MCP is configured and requested if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + req = bifrost.mcpManager.AddToolsToRequest(ctx, req) } pipeline := bifrost.getPluginPipeline() @@ -2690,7 +2793,7 @@ func executeRequestWithRetries[T any]( // Calculate and apply backoff backoff := calculateBackoff(attempts-1, config) - logger.Debug("sleeping for %s", backoff) + logger.Debug("sleeping for %s before retry", backoff) time.Sleep(backoff) } @@ -3516,7 +3619,7 @@ func (bifrost *Bifrost) Shutdown() { // Cleanup MCP manager if bifrost.mcpManager != nil { - err := bifrost.mcpManager.cleanup() + err := bifrost.mcpManager.Cleanup() if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error cleaning up MCP manager: %s", err.Error())) } diff --git a/core/changelog.md b/core/changelog.md index e69de29bb..b717cf5e0 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1,3 @@ +- feat: added code mode to mcp +- feat: added health monitoring to mcp +- feat: added responses format tool execution support to mcp \ No newline at end of file diff --git a/core/chatbot_test.go b/core/chatbot_test.go index ff8cef7a2..67aae2972 100644 --- a/core/chatbot_test.go +++ b/core/chatbot_test.go @@ -563,7 +563,7 @@ func (s *ChatSession) handleToolCalls(assistantMessage schemas.ChatMessage) (str stopChan, wg := startLoader() // Execute the tool using Bifrost's integrated MCP functionality - toolResult, err := s.client.ExecuteMCPTool(context.Background(), toolCall) + toolResult, err := s.client.ExecuteChatMCPTool(context.Background(), toolCall) // Stop loading animation stopLoader(stopChan, wg) diff --git a/core/go.mod b/core/go.mod index 194bf61ed..724375d92 100644 --- a/core/go.mod +++ b/core/go.mod @@ -10,6 +10,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 github.com/aws/smithy-go v1.24.0 github.com/bytedance/sonic v1.14.2 + github.com/clarkmcc/go-typescript v0.7.0 + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 github.com/google/uuid v1.6.0 github.com/hajimehoshi/go-mp3 v0.3.4 github.com/mark3labs/mcp-go v0.43.2 @@ -22,6 +24,7 @@ require ( require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/Masterminds/semver/v3 v3.3.1 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect @@ -42,6 +45,9 @@ require ( github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/core/go.sum b/core/go.sum index a23bef3a4..656faafb2 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,5 +1,7 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -7,7 +9,9 @@ github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXx github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 h1:489krEF9xIGkOaaX3CE/Be2uWjiXrkCH6gUX+bZA/BU= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4/go.mod h1:IOAPF6oT9KCsceNTvvYMNHy0+kMF8akOjeDvPENWxp4= github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8= +github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI= github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6/go.mod h1:SgHzKjEVsdQr6Opor0ihgWtkWdfRAIwxYzSJ8O85VHY= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= @@ -27,9 +31,11 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:i github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 h1:SWTxh/EcUCDVqi/0s26V6pVUq0BBG7kx0tDTmF/hCgA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= @@ -43,7 +49,11 @@ github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx2 github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= +github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= +github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -51,11 +61,19 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= +github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -64,6 +82,7 @@ github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/Sisy github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -76,6 +95,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -97,8 +117,11 @@ github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qq github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= @@ -106,6 +129,7 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2 github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= @@ -113,17 +137,24 @@ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3i github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= +golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/mcp.go b/core/mcp.go index b1eb7a73e..b409cf49e 100644 --- a/core/mcp.go +++ b/core/mcp.go @@ -1136,14 +1136,8 @@ func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) ( return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") } - // Type assert to ensure we have a proper MCP server - mcpServer, ok := config.InProcessServer.(*server.MCPServer) - if !ok { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcessServer must be a *server.MCPServer instance") - } - // Create in-process client directly connected to the provided server - inProcessClient, err := client.NewInProcessClient(mcpServer) + inProcessClient, err := client.NewInProcessClient(config.InProcessServer) if err != nil { return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) } diff --git a/core/mcp/agent.go b/core/mcp/agent.go new file mode 100644 index 000000000..bd2b6da6d --- /dev/null +++ b/core/mcp/agent.go @@ -0,0 +1,473 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API. +// It orchestrates iterative tool execution up to the maximum depth, handling +// auto-executable and non-auto-executable tools appropriately. +// +// Parameters: +// - ctx: Context for agent execution +// - maxAgentDepth: Maximum number of agent iterations allowed +// - originalReq: The original chat request +// - initialResponse: The initial chat response containing tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func ExecuteAgentForChatRequest( + ctx *context.Context, + maxAgentDepth int, + originalReq *schemas.BifrostChatRequest, + initialResponse *schemas.BifrostChatResponse, + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), + fetchNewRequestIDFunc func(ctx context.Context) string, + executeToolFunc func(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // Create adapter for Chat API + adapter := &chatAPIAdapter{ + originalReq: originalReq, + initialResponse: initialResponse, + makeReq: makeReq, + } + + result, err := executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) + if err != nil { + return nil, err + } + + chatResponse, ok := result.(*schemas.BifrostChatResponse) + // Should never happen, but just in case + if !ok { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "Failed to convert result to schemas.BifrostChatResponse", + }, + } + } + + return chatResponse, nil +} + +// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API. +// It orchestrates iterative tool execution up to the maximum depth, handling +// auto-executable and non-auto-executable tools appropriately. +// +// Parameters: +// - ctx: Context for agent execution +// - maxAgentDepth: Maximum number of agent iterations allowed +// - originalReq: The original responses request +// - initialResponse: The initial responses response containing tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func ExecuteAgentForResponsesRequest( + ctx *context.Context, + maxAgentDepth int, + originalReq *schemas.BifrostResponsesRequest, + initialResponse *schemas.BifrostResponsesResponse, + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), + fetchNewRequestIDFunc func(ctx context.Context) string, + executeToolFunc func(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // Create adapter for Responses API + adapter := &responsesAPIAdapter{ + originalReq: originalReq, + initialResponse: initialResponse, + makeReq: makeReq, + } + + result, err := executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) + if err != nil { + return nil, err + } + + responsesResponse, ok := result.(*schemas.BifrostResponsesResponse) + // Should never happen, but just in case + if !ok { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "Failed to convert result to schemas.BifrostResponsesResponse", + }, + } + } + + return responsesResponse, nil +} + +// executeAgent handles the generic agent mode execution loop using an API adapter pattern. +// It iteratively executes tools, separates auto-executable from non-auto-executable tools, +// executes auto-executable tools in parallel, and continues the loop until no more tool +// calls are present or the maximum depth is reached. +// +// Parameters: +// - ctx: Context for agent execution (may be modified to add request IDs) +// - maxAgentDepth: Maximum number of agent iterations allowed +// - adapter: API adapter that abstracts differences between Chat and Responses APIs +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - interface{}: The final response after agent execution (type depends on adapter) +// - *schemas.BifrostError: Any error that occurred during agent execution +func executeAgent( + ctx *context.Context, + maxAgentDepth int, + adapter agentAPIAdapter, + fetchNewRequestIDFunc func(ctx context.Context) string, + executeToolFunc func(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (interface{}, *schemas.BifrostError) { + logger.Debug("Entering agent mode - detected tool calls in response") + + // Get initial response from adapter + currentResponse := adapter.getInitialResponse() + + // Create conversation history starting with original messages + conversationHistory := adapter.getConversationHistory() + + depth := 0 + + // Track all executed tool results and tool calls across all iterations + allExecutedToolResults := make([]*schemas.ChatMessage, 0) + allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + + originalRequestID, ok := (*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + if ok { + *ctx = context.WithValue(*ctx, schemas.BifrostMCPAgentOriginalRequestID, originalRequestID) + } + + for depth < maxAgentDepth { + depth++ + toolCalls := adapter.extractToolCalls(currentResponse) + if len(toolCalls) == 0 { + logger.Debug("No more tool calls found, exiting agent mode") + break + } + + logger.Debug(fmt.Sprintf("Agent mode depth %d: executing %d tool calls", depth, len(toolCalls))) + + // Separate tools into auto-executable and non-auto-executable groups + var autoExecutableTools []schemas.ChatAssistantMessageToolCall + var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall + + for _, toolCall := range toolCalls { + if toolCall.Function.Name == nil { + // Skip tools without names + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + toolName := *toolCall.Function.Name + client := clientManager.GetClientForTool(toolName) + if client == nil { + // Allow code mode list and read tool tools + if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + continue + } else if toolName == ToolTypeExecuteToolCode { + // Build allowed auto-execution tools map for code mode validation + allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(*ctx, clientManager) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + // Step 1: Convert literal \n escape sequences to actual newlines for parsing + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + if len(codeWithNewlines) != len(code) { + logger.Debug(fmt.Sprintf("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix)) + } + + // Step 2: Extract tool calls from code during AST formation + extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines) + if err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + logger.Debug(fmt.Sprintf("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))) + + // Step 3: Validate all tool calls against allowedAutoExecutionTools + canAutoExecute := true + if len(extractedToolCalls) > 0 { + // If there are tool calls, we need allowedAutoExecutionTools to validate them + if len(allowedAutoExecutionTools) == 0 { + logger.Debug(fmt.Sprintf("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)) + canAutoExecute = false + } else { + logger.Debug(fmt.Sprintf("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))) + + // Validate each tool call + for _, extractedToolCall := range extractedToolCalls { + isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools) + if !isAllowed { + logger.Debug(fmt.Sprintf("%s Tool call %s.%s: allowed=%v", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName, isAllowed)) + logger.Debug(fmt.Sprintf("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)) + canAutoExecute = false + break + } + } + if canAutoExecute { + logger.Debug(fmt.Sprintf("%s All tool calls validated successfully", CodeModeLogPrefix)) + } + } + } else { + logger.Debug(fmt.Sprintf("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)) + } + + // Add to appropriate list based on validation result + if canAutoExecute { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed (validation passed)", toolName)) + } else { + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed (validation failed)", toolName)) + } + continue + } + // Else, if client not found, treat as non-auto-executable (can be a manually passed tool) + logger.Debug(fmt.Sprintf("Client not found for tool %s, treating as non-auto-executable", toolName)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + // Check if tool can be auto-executed + if canAutoExecuteTool(toolName, client.ExecutionConfig) { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + } else { + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed", toolName)) + } + } + + logger.Debug(fmt.Sprintf("Auto-executable tools: %d", len(autoExecutableTools))) + logger.Debug(fmt.Sprintf("Non-auto-executable tools: %d", len(nonAutoExecutableTools))) + + // Execute auto-executable tools first + var executedToolResults []*schemas.ChatMessage + if len(autoExecutableTools) > 0 { + // Add assistant message with auto-executable tool calls to conversation + conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse) + + // Execute all auto-executable tool calls parallelly + wg := sync.WaitGroup{} + wg.Add(len(autoExecutableTools)) + channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools)) + for _, toolCall := range autoExecutableTools { + go func(toolCall schemas.ChatAssistantMessageToolCall) { + defer wg.Done() + toolResult, toolErr := executeToolFunc(*ctx, toolCall) + if toolErr != nil { + logger.Warn(fmt.Sprintf("Tool execution failed: %v", toolErr)) + channelToolResults <- createToolResultMessage(toolCall, "", toolErr) + } else { + channelToolResults <- toolResult + } + }(toolCall) + } + wg.Wait() + close(channelToolResults) + + // Collect tool results + executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools)) + for toolResult := range channelToolResults { + executedToolResults = append(executedToolResults, toolResult) + } + + // Track executed tool results and calls across all iterations + allExecutedToolResults = append(allExecutedToolResults, executedToolResults...) + allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...) + + // Add tool results to conversation history + conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults) + } + + // If there are non-auto-executable tools, return them immediately without continuing the loop + if len(nonAutoExecutableTools) > 0 { + logger.Debug(fmt.Sprintf("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))) + // Return as is if its the first iteration + if depth == 1 && len(allExecutedToolResults) == 0 { + return currentResponse, nil + } + // Create response with all executed tool results from all iterations, and non-auto-executable tool calls + return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil + } + + // Create new request with updated conversation history + newReq := adapter.createNewRequest(conversationHistory) + + if fetchNewRequestIDFunc != nil { + newID := fetchNewRequestIDFunc(*ctx) + if newID != "" { + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyRequestID, newID) + } + } + + // Make new LLM request + response, err := adapter.makeLLMCall(*ctx, newReq) + if err != nil { + logger.Error("Agent mode: LLM request failed: %v", err) + return nil, err + } + + currentResponse = response + } + + logger.Debug(fmt.Sprintf("Agent mode completed after %d iterations", depth)) + return currentResponse, nil +} + +// extractToolCalls extracts all tool calls from a chat response. +// It iterates through all choices in the response and collects tool calls +// from assistant messages. +// +// Parameters: +// - response: The chat response to extract tool calls from +// +// Returns: +// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found +func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall { + if !hasToolCallsForChatResponse(response) { + return nil + } + + var toolCalls []schemas.ChatAssistantMessageToolCall + for _, choice := range response.Choices { + if choice.ChatNonStreamResponseChoice != nil && + choice.ChatNonStreamResponseChoice.Message != nil && + choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil { + toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...) + } + } + + return toolCalls +} + +// createToolResultMessage creates a tool result message from tool execution. +// It formats the result or error into a chat message with the appropriate tool call ID. +// +// Parameters: +// - toolCall: The original tool call that was executed +// - result: The successful execution result (ignored if err is not nil) +// - err: Any error that occurred during tool execution +// +// Returns: +// - *schemas.ChatMessage: A tool message containing the execution result or error +func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage { + var content string + if err != nil { + content = fmt.Sprintf("Error executing tool %s: %s", + func() string { + if toolCall.Function.Name != nil { + return *toolCall.Function.Name + } + return "unknown" + }(), err.Error()) + } else { + content = result + } + + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &content, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools. +// It processes code mode clients and parses their ToolsToAutoExecute configuration to create +// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code. +// +// Parameters: +// - ctx: Context for accessing client tools +// - clientManager: Client manager for accessing MCP clients +// +// Returns: +// - []string: List of all client names +// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code) +func buildAllowedAutoExecutionTools(ctx context.Context, clientManager ClientManager) ([]string, map[string][]string) { + allowedTools := make(map[string][]string) + availableToolsPerClient := clientManager.GetToolPerClient(ctx) + allClientNames := []string{} + + for clientName := range availableToolsPerClient { + client := clientManager.GetClientByName(clientName) + if client == nil { + continue + } + allClientNames = append(allClientNames, clientName) + + // Only include code mode clients + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + + // Get auto-executable tools from config + toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute + if len(toolsToAutoExecute) == 0 { + // No auto-executable tools configured for this client + continue + } + + // Parse tool names (as they appear in JavaScript code) + autoExecutableTools := []string{} + for _, originalToolName := range toolsToAutoExecute { + // Handle wildcard "*" - means all tools are auto-executable + if originalToolName == "*" { + autoExecutableTools = append(autoExecutableTools, "*") + continue + } + // Use parsed tool name (as it appears in code) + parsedToolName := parseToolName(originalToolName) + autoExecutableTools = append(autoExecutableTools, parsedToolName) + } + + // Add to map if there are auto-executable tools + if len(autoExecutableTools) > 0 { + allowedTools[clientName] = autoExecutableTools + } + } + + return allClientNames, allowedTools +} diff --git a/core/mcp/agent_adaptors.go b/core/mcp/agent_adaptors.go new file mode 100644 index 000000000..d6118cbf7 --- /dev/null +++ b/core/mcp/agent_adaptors.go @@ -0,0 +1,559 @@ +package mcp + +import ( + "context" + "fmt" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// agentAPIAdapter defines the interface for API-specific operations in agent mode. +// This adapter pattern allows the agent execution logic to work with both Chat Completions +// and Responses APIs without requiring API-specific code in the agent loop. +// +// The adapter handles format conversions at the boundaries: +// - Responses API requests/responses are converted to/from Chat API format +// - Tool calls are extracted in Chat format for uniform processing +// - Results are converted back to the original API format for the response +// +// This design ensures that: +// 1. Tool execution logic is format-agnostic +// 2. Both APIs have feature parity +// 3. Conversions are localized to adapters +// 4. The agent loop remains API-neutral +type agentAPIAdapter interface { + // Extract conversation history from the original request + getConversationHistory() []interface{} + + // Get original request + getOriginalRequest() interface{} + + // Get initial response + getInitialResponse() interface{} + + // Check if response has tool calls + hasToolCalls(response interface{}) bool + + // Extract tool calls from response. + // For Chat API: Returns tool calls directly from the response. + // For Responses API: Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall for processing. + extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall + + // Add assistant message with tool calls to conversation + addAssistantMessage(conversation []interface{}, response interface{}) []interface{} + + // Add tool results to conversation. + // For Chat API: Adds ChatMessage results directly. + // For Responses API: Converts ChatMessage results to ResponsesMessage via ToResponsesToolMessage(). + addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} + + // Create new request with updated conversation + createNewRequest(conversation []interface{}) interface{} + + // Make LLM call + makeLLMCall(ctx context.Context, request interface{}) (interface{}, *schemas.BifrostError) + + // Create response with executed tools and non-auto-executable calls + createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, + ) interface{} +} + +// chatAPIAdapter implements agentAPIAdapter for Chat API +type chatAPIAdapter struct { + originalReq *schemas.BifrostChatRequest + initialResponse *schemas.BifrostChatResponse + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) +} + +// responsesAPIAdapter implements agentAPIAdapter for Responses API. +// It enables the agent mode execution loop to work with Responses API requests and responses +// by handling format conversions transparently. +// +// Key conversions performed: +// - extractToolCalls(): Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall +// via BifrostResponsesResponse.ToBifrostChatResponse() and existing extraction logic +// - addToolResults(): Converts ChatMessage tool results back to ResponsesMessage +// via ChatMessage.ToResponsesMessages() and ToResponsesToolMessage() +// - createNewRequest(): Builds a new BifrostResponsesRequest from converted conversation +// - createResponseWithExecutedTools(): Creates a Responses response with results and pending tools +// +// This adapter enables full feature parity between Chat Completions and Responses APIs +// for tool execution in agent mode. +type responsesAPIAdapter struct { + originalReq *schemas.BifrostResponsesRequest + initialResponse *schemas.BifrostResponsesResponse + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) +} + +// Chat API adapter implementations +func (c *chatAPIAdapter) getConversationHistory() []interface{} { + history := make([]interface{}, 0) + if c.originalReq.Input != nil { + for _, msg := range c.originalReq.Input { + history = append(history, msg) + } + } + return history +} + +func (c *chatAPIAdapter) getOriginalRequest() interface{} { + return c.originalReq +} + +func (c *chatAPIAdapter) getInitialResponse() interface{} { + return c.initialResponse +} + +func (c *chatAPIAdapter) hasToolCalls(response interface{}) bool { + chatResponse := response.(*schemas.BifrostChatResponse) + return hasToolCallsForChatResponse(chatResponse) +} + +func (c *chatAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall { + chatResponse := response.(*schemas.BifrostChatResponse) + return extractToolCalls(chatResponse) +} + +func (c *chatAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} { + chatResponse := response.(*schemas.BifrostChatResponse) + for _, choice := range chatResponse.Choices { + if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil { + conversation = append(conversation, *choice.ChatNonStreamResponseChoice.Message) + } + } + return conversation +} + +func (c *chatAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} { + for _, toolResult := range toolResults { + conversation = append(conversation, *toolResult) + } + return conversation +} + +func (c *chatAPIAdapter) createNewRequest(conversation []interface{}) interface{} { + // Convert conversation back to ChatMessage slice + chatMessages := make([]schemas.ChatMessage, 0, len(conversation)) + for _, msg := range conversation { + chatMessages = append(chatMessages, msg.(schemas.ChatMessage)) + } + + return &schemas.BifrostChatRequest{ + Provider: c.originalReq.Provider, + Model: c.originalReq.Model, + Fallbacks: c.originalReq.Fallbacks, + Params: c.originalReq.Params, + Input: chatMessages, + } +} + +func (c *chatAPIAdapter) makeLLMCall(ctx context.Context, request interface{}) (interface{}, *schemas.BifrostError) { + chatRequest := request.(*schemas.BifrostChatRequest) + return c.makeReq(ctx, chatRequest) +} + +func (c *chatAPIAdapter) createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) interface{} { + chatResponse := response.(*schemas.BifrostChatResponse) + return createChatResponseWithExecutedToolsAndNonAutoExecutableCalls( + chatResponse, + executedToolResults, + executedToolCalls, + nonAutoExecutableToolCalls, + ) +} + +// createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response +// that includes executed tool results and non-auto-executable tool calls. The response +// contains a formatted text summary of executed tool results and includes the non-auto-executable +// tool calls for the caller to handle. The finish reason is set to "stop" to prevent +// further agent loop iterations. +// +// Parameters: +// - originalResponse: The original chat response to copy metadata from +// - executedToolResults: List of tool execution results from auto-executable tools +// - executedToolCalls: List of tool calls that were executed +// - nonAutoExecutableToolCalls: List of tool calls that require manual execution +// +// Returns: +// - *schemas.BifrostChatResponse: A new chat response with executed results and pending tool calls +func createChatResponseWithExecutedToolsAndNonAutoExecutableCalls( + originalResponse *schemas.BifrostChatResponse, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) *schemas.BifrostChatResponse { + // Start with a copy of the original response metadata + response := &schemas.BifrostChatResponse{ + ID: originalResponse.ID, + Object: originalResponse.Object, + Created: originalResponse.Created, + Model: originalResponse.Model, + Choices: make([]schemas.BifrostResponseChoice, 0), + ServiceTier: originalResponse.ServiceTier, + SystemFingerprint: originalResponse.SystemFingerprint, + Usage: originalResponse.Usage, + ExtraFields: originalResponse.ExtraFields, + SearchResults: originalResponse.SearchResults, + Videos: originalResponse.Videos, + Citations: originalResponse.Citations, + } + + // Build a map from tool call ID to tool name for easy lookup + toolCallIDToName := make(map[string]string) + for _, toolCall := range executedToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name + } + } + + // Build content text showing executed tool results + var contentText string + if len(executedToolResults) > 0 { + // Format tool results as JSON-like structure + toolResultsMap := make(map[string]interface{}) + for _, toolResult := range executedToolResults { + // Get tool name from tool call ID mapping + var toolName string + if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil { + toolCallID := *toolResult.ChatToolMessage.ToolCallID + if name, ok := toolCallIDToName[toolCallID]; ok { + toolName = name + } else { + toolName = toolCallID // Fallback to tool call ID if name not found + } + } else { + toolName = "unknown_tool" + } + + // Extract output from tool result + var output interface{} + if toolResult.Content != nil { + if toolResult.Content.ContentStr != nil { + output = *toolResult.Content.ContentStr + } else if toolResult.Content.ContentBlocks != nil { + // Convert content blocks to a readable format + blocks := make([]map[string]interface{}, 0) + for _, block := range toolResult.Content.ContentBlocks { + blockMap := make(map[string]interface{}) + blockMap["type"] = string(block.Type) + if block.Text != nil { + blockMap["text"] = *block.Text + } + blocks = append(blocks, blockMap) + } + output = blocks + } + } + toolResultsMap[toolName] = output + } + + // Convert to JSON string for display + jsonBytes, err := sonic.Marshal(toolResultsMap) + if err != nil { + // Fallback to simple string representation + contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap) + } else { + contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes)) + } + } else { + contentText = "Now I shall call these tools next..." + } + + // Create content with the formatted text + content := &schemas.ChatMessageContent{ + ContentStr: &contentText, + } + + // Determine finish reason + // Note: We set finish_reason to "stop" (not "tool_calls") for non-auto-executable tools + // to prevent the agent loop from retrying. The tool calls are still included in the response + // for the caller to handle, but setting finish_reason to "stop" ensures hasToolCalls returns false + // and the agent loop exits properly. + finishReason := "stop" + + // Create a single choice with the formatted content and non-auto-executable tool calls + response.Choices = append(response.Choices, schemas.BifrostResponseChoice{ + Index: 0, + FinishReason: &finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: content, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: nonAutoExecutableToolCalls, + }, + }, + }, + }) + + return response +} + +// Responses API adapter implementations +func (r *responsesAPIAdapter) getConversationHistory() []interface{} { + history := make([]interface{}, 0) + if r.originalReq.Input != nil { + for _, msg := range r.originalReq.Input { + history = append(history, msg) + } + } + return history +} + +func (r *responsesAPIAdapter) getOriginalRequest() interface{} { + return r.originalReq +} + +func (r *responsesAPIAdapter) getInitialResponse() interface{} { + return r.initialResponse +} + +func (r *responsesAPIAdapter) hasToolCalls(response interface{}) bool { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + return hasToolCallsForResponsesResponse(responsesResponse) +} + +func (r *responsesAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + // Convert to Chat format and extract tool calls using existing logic + chatResponse := responsesResponse.ToBifrostChatResponse() + return extractToolCalls(chatResponse) +} + +func (r *responsesAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + for _, output := range responsesResponse.Output { + conversation = append(conversation, output) + } + return conversation +} + +func (r *responsesAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} { + for _, toolResult := range toolResults { + // Convert using existing converter + responsesMessages := toolResult.ToResponsesMessages() + for _, respMsg := range responsesMessages { + conversation = append(conversation, respMsg) + } + } + return conversation +} + +func (r *responsesAPIAdapter) createNewRequest(conversation []interface{}) interface{} { + // Convert conversation back to ResponsesMessage slice + responsesMessages := make([]schemas.ResponsesMessage, 0, len(conversation)) + for _, msg := range conversation { + responsesMessages = append(responsesMessages, msg.(schemas.ResponsesMessage)) + } + + return &schemas.BifrostResponsesRequest{ + Provider: r.originalReq.Provider, + Model: r.originalReq.Model, + Fallbacks: r.originalReq.Fallbacks, + Params: r.originalReq.Params, + Input: responsesMessages, + } +} + +func (r *responsesAPIAdapter) makeLLMCall(ctx context.Context, request interface{}) (interface{}, *schemas.BifrostError) { + responsesRequest := request.(*schemas.BifrostResponsesRequest) + return r.makeReq(ctx, responsesRequest) +} + +func (r *responsesAPIAdapter) createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) interface{} { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + + // Create response with executed tools directly on Responses schema + return createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls( + responsesResponse, + executedToolResults, + executedToolCalls, + nonAutoExecutableToolCalls, + ) +} + +// createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response +// that includes executed tool results and non-auto-executable tool calls. The response +// contains a formatted text summary of executed tool results and includes the non-auto-executable +// tool calls for the caller to handle. All Response-specific fields are preserved. +// +// Parameters: +// - originalResponse: The original responses response to copy metadata from +// - executedToolResults: List of tool execution results from auto-executable tools +// - executedToolCalls: List of tool calls that were executed +// - nonAutoExecutableToolCalls: List of tool calls that require manual execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: A new responses response with executed results and pending tool calls +func createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls( + originalResponse *schemas.BifrostResponsesResponse, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) *schemas.BifrostResponsesResponse { + // Start with a copy of the original response, preserving all Response-specific fields + response := &schemas.BifrostResponsesResponse{ + ID: originalResponse.ID, + Background: originalResponse.Background, + Conversation: originalResponse.Conversation, + CreatedAt: originalResponse.CreatedAt, + Error: originalResponse.Error, + Include: originalResponse.Include, + IncompleteDetails: originalResponse.IncompleteDetails, + Instructions: originalResponse.Instructions, + MaxOutputTokens: originalResponse.MaxOutputTokens, + MaxToolCalls: originalResponse.MaxToolCalls, + Metadata: originalResponse.Metadata, + ParallelToolCalls: originalResponse.ParallelToolCalls, + PreviousResponseID: originalResponse.PreviousResponseID, + Prompt: originalResponse.Prompt, + PromptCacheKey: originalResponse.PromptCacheKey, + Reasoning: originalResponse.Reasoning, + SafetyIdentifier: originalResponse.SafetyIdentifier, + ServiceTier: originalResponse.ServiceTier, + StreamOptions: originalResponse.StreamOptions, + Store: originalResponse.Store, + Temperature: originalResponse.Temperature, + Text: originalResponse.Text, + TopLogProbs: originalResponse.TopLogProbs, + TopP: originalResponse.TopP, + ToolChoice: originalResponse.ToolChoice, + Tools: originalResponse.Tools, + Truncation: originalResponse.Truncation, + Usage: originalResponse.Usage, + ExtraFields: originalResponse.ExtraFields, + // Perplexity-specific fields + SearchResults: originalResponse.SearchResults, + Videos: originalResponse.Videos, + Citations: originalResponse.Citations, + Output: make([]schemas.ResponsesMessage, 0), + } + + // Build a map from tool call ID to tool name for easy lookup + toolCallIDToName := make(map[string]string) + for _, toolCall := range executedToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name + } + } + + // Build content text showing executed tool results + var contentText string + if len(executedToolResults) > 0 { + // Format tool results as JSON-like structure + toolResultsMap := make(map[string]interface{}) + for _, toolResult := range executedToolResults { + // Get tool name from tool call ID mapping + var toolName string + if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil { + toolCallID := *toolResult.ChatToolMessage.ToolCallID + if name, ok := toolCallIDToName[toolCallID]; ok { + toolName = name + } else { + toolName = toolCallID // Fallback to tool call ID if name not found + } + } else { + toolName = "unknown_tool" + } + + // Extract output from tool result + var output interface{} + if toolResult.Content != nil { + if toolResult.Content.ContentStr != nil { + output = *toolResult.Content.ContentStr + } else if toolResult.Content.ContentBlocks != nil { + // Convert content blocks to a readable format + blocks := make([]map[string]interface{}, 0) + for _, block := range toolResult.Content.ContentBlocks { + blockMap := make(map[string]interface{}) + blockMap["type"] = string(block.Type) + if block.Text != nil { + blockMap["text"] = *block.Text + } + blocks = append(blocks, blockMap) + } + output = blocks + } + } + toolResultsMap[toolName] = output + } + + // Convert to JSON string for display + jsonBytes, err := sonic.Marshal(toolResultsMap) + if err != nil { + // Fallback to simple string representation + contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap) + } else { + contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes)) + } + } else { + contentText = "Now I shall call these tools next..." + } + + // Create assistant message with the formatted text content + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + assistantMessage := schemas.ResponsesMessage{ + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &contentText, + }, + }, + }, + } + response.Output = append(response.Output, assistantMessage) + + // Add non-auto-executable tool calls as separate function_call messages + for _, toolCall := range nonAutoExecutableToolCalls { + functionCallType := schemas.ResponsesMessageTypeFunctionCall + assistantRole := schemas.ResponsesInputMessageRoleAssistant + + var callID *string + if toolCall.ID != nil && *toolCall.ID != "" { + callID = toolCall.ID + } + + var namePtr *string + if toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + namePtr = toolCall.Function.Name + } + + var argumentsPtr *string + if toolCall.Function.Arguments != "" { + argumentsPtr = &toolCall.Function.Arguments + } + + toolCallMessage := schemas.ResponsesMessage{ + Type: &functionCallType, + Role: &assistantRole, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: callID, + Name: namePtr, + Arguments: argumentsPtr, + }, + } + + response.Output = append(response.Output, toolCallMessage) + } + + return response +} diff --git a/core/mcp/agent_test.go b/core/mcp/agent_test.go new file mode 100644 index 000000000..86a24e6cf --- /dev/null +++ b/core/mcp/agent_test.go @@ -0,0 +1,719 @@ +package mcp + +import ( + "context" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// MockLLMCaller implements schemas.BifrostLLMCaller for testing +type MockLLMCaller struct { + chatResponses []*schemas.BifrostChatResponse + responsesResponses []*schemas.BifrostResponsesResponse + chatCallCount int + responsesCallCount int +} + +func (m *MockLLMCaller) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if m.chatCallCount >= len(m.chatResponses) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "no more mock chat responses available", + }, + } + } + + response := m.chatResponses[m.chatCallCount] + m.chatCallCount++ + return response, nil +} + +func (m *MockLLMCaller) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if m.responsesCallCount >= len(m.responsesResponses) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "no more mock responses api responses available", + }, + } + } + + response := m.responsesResponses[m.responsesCallCount] + m.responsesCallCount++ + return response, nil +} + +// MockLogger implements schemas.Logger for testing +type MockLogger struct{} + +func (m *MockLogger) Debug(msg string, args ...any) {} +func (m *MockLogger) Info(msg string, args ...any) {} +func (m *MockLogger) Warn(msg string, args ...any) {} +func (m *MockLogger) Error(msg string, args ...any) {} +func (m *MockLogger) Fatal(msg string, args ...any) {} +func (m *MockLogger) SetLevel(level schemas.LogLevel) {} +func (m *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} + +// MockClientManager implements ClientManager for testing +type MockClientManager struct{} + +func (m *MockClientManager) GetClientForTool(toolName string) *schemas.MCPClientState { + return nil // Return nil to simulate no client found +} + +func (m *MockClientManager) GetClientByName(clientName string) *schemas.MCPClientState { + return nil +} + +func (m *MockClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + return make(map[string][]schemas.ChatTool) +} + +func TestHasToolCallsForChatResponse(t *testing.T) { + // Test nil response + if hasToolCallsForChatResponse(nil) { + t.Error("Should return false for nil response") + } + + // Test empty choices + emptyResponse := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{}, + } + if hasToolCallsForChatResponse(emptyResponse) { + t.Error("Should return false for response with empty choices") + } + + // Test response with tool_calls finish reason + toolCallsResponse := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("tool_calls"), + }, + }, + } + if !hasToolCallsForChatResponse(toolCallsResponse) { + t.Error("Should return true for response with tool_calls finish reason") + } + + // Test response with actual tool calls + responseWithToolCalls := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + }, + }, + }, + }, + }, + }, + }, + }, + } + if !hasToolCallsForChatResponse(responseWithToolCalls) { + t.Error("Should return true for response with tool calls in message") + } + + // Test response with stop finish reason (should return false even with tool calls) + responseWithStopReason := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + }, + }, + }, + }, + }, + }, + }, + }, + } + if hasToolCallsForChatResponse(responseWithStopReason) { + t.Error("Should return false for response with stop finish reason even with tool calls") + } +} + +func TestExtractToolCalls(t *testing.T) { + // Test response without tool calls + responseNoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + }, + }, + } + + toolCalls := extractToolCalls(responseNoTools) + if len(toolCalls) != 0 { + t.Error("Should return empty slice for response without tool calls") + } + + // Test response with tool calls + expectedToolCalls := []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + Arguments: `{"param": "value"}`, + }, + }, + } + + responseWithTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: expectedToolCalls, + }, + }, + }, + }, + }, + } + + actualToolCalls := extractToolCalls(responseWithTools) + if len(actualToolCalls) != 1 { + t.Errorf("Expected 1 tool call, got %d", len(actualToolCalls)) + } + + if actualToolCalls[0].Function.Name == nil || *actualToolCalls[0].Function.Name != "test_tool" { + t.Error("Tool call name mismatch") + } +} + +func TestExecuteAgentForChatRequest(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Test with response that has no tool calls - should return immediately + responseNoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello, how can I help you?"), + }, + }, + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return llmCaller.ChatCompletionRequest(ctx, req) + } + originalReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + + ctx := context.Background() + + result, err := ExecuteAgentForChatRequest(&ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{}) + if err != nil { + t.Errorf("Expected no error for response without tool calls, got: %v", err) + } + if result != responseNoTools { + t.Error("Expected same response to be returned for response without tool calls") + } +} + +func TestExecuteAgentForChatRequest_WithNonAutoExecutableTools(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Create a response with tool calls that will NOT be auto-executed + responseWithNonAutoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("tool_calls"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("I need to call a tool"), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("non_auto_executable_tool"), + Arguments: `{"param": "value"}`, + }, + }, + }, + }, + }, + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return llmCaller.ChatCompletionRequest(ctx, req) + } + originalReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + } + + ctx := context.Background() + + // Execute agent mode - should return immediately with non-auto-executable tools + result, err := ExecuteAgentForChatRequest(&ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{}) + + // Should not return error for non-auto-executable tools + if err != nil { + t.Errorf("Expected no error for non-auto-executable tools, got: %v", err) + } + + // Should return a response with the non-auto-executable tool calls + if result == nil { + t.Error("Expected result to be returned for non-auto-executable tools") + } + + // Verify that no LLM calls were made (since tools are non-auto-executable) + if llmCaller.chatCallCount != 0 { + t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.chatCallCount) + } +} + +func TestHasToolCallsForResponsesResponse(t *testing.T) { + // Test nil response + if hasToolCallsForResponsesResponse(nil) { + t.Error("Should return false for nil response") + } + + // Test empty output + emptyResponse := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{}, + } + if hasToolCallsForResponsesResponse(emptyResponse) { + t.Error("Should return false for response with empty output") + } + + // Test response with function call + responseWithFunctionCall := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call_123"), + Name: schemas.Ptr("test_tool"), + }, + }, + }, + } + if !hasToolCallsForResponsesResponse(responseWithFunctionCall) { + t.Error("Should return true for response with function call") + } + + // Test response with function call but no ResponsesToolMessage + responseWithoutToolMessage := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + // No ResponsesToolMessage + }, + }, + } + if hasToolCallsForResponsesResponse(responseWithoutToolMessage) { + t.Error("Should return false for response with function call type but no ResponsesToolMessage") + } + + // Test response with regular message + responseWithRegularMessage := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + if hasToolCallsForResponsesResponse(responseWithRegularMessage) { + t.Error("Should return false for response with regular message") + } +} + +func TestExecuteAgentForResponsesRequest(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Test with response that has no tool calls - should return immediately + responseNoTools := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello, how can I help you?"), + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return llmCaller.ResponsesRequest(ctx, req) + } + originalReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + + ctx := context.Background() + + result, err := ExecuteAgentForResponsesRequest(&ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{}) + if err != nil { + t.Errorf("Expected no error for response without tool calls, got: %v", err) + } + if result != responseNoTools { + t.Error("Expected same response to be returned for response without tool calls") + } +} + +func TestExecuteAgentForResponsesRequest_WithNonAutoExecutableTools(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Create a response with tool calls that will NOT be auto-executed + responseWithNonAutoTools := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call_123"), + Name: schemas.Ptr("non_auto_executable_tool"), + Arguments: schemas.Ptr(`{"param": "value"}`), + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return llmCaller.ResponsesRequest(ctx, req) + } + originalReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + } + + ctx := context.Background() + + // Execute agent mode - should return immediately with non-auto-executable tools + result, err := ExecuteAgentForResponsesRequest(&ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{}) + + // Should not return error for non-auto-executable tools + if err != nil { + t.Errorf("Expected no error for non-auto-executable tools, got: %v", err) + } + + // Should return a response with the non-auto-executable tool calls + if result == nil { + t.Error("Expected result to be returned for non-auto-executable tools") + } + + // Verify that no LLM calls were made (since tools are non-auto-executable) + if llmCaller.responsesCallCount != 0 { + t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.responsesCallCount) + } +} + +// ============================================================================ +// CONVERTER TESTS (Phase 2) +// ============================================================================ + +// TestResponsesToolMessageToChatAssistantMessageToolCall tests conversion of Responses tool message to Chat tool call +func TestResponsesToolMessageToChatAssistantMessageToolCall(t *testing.T) { + // Test with valid tool message + responsesToolMsg := &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call-123"), + Name: schemas.Ptr("calculate"), + Arguments: schemas.Ptr("{\"x\": 10, \"y\": 20}"), + } + + chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() + + if chatToolCall == nil { + t.Fatal("Expected non-nil ChatAssistantMessageToolCall") + } + + if chatToolCall.Type == nil || *chatToolCall.Type != "function" { + t.Errorf("Expected Type 'function', got %v", chatToolCall.Type) + } + + if chatToolCall.Function.Name == nil || *chatToolCall.Function.Name != "calculate" { + t.Errorf("Expected Name 'calculate', got %v", chatToolCall.Function.Name) + } + + if chatToolCall.Function.Arguments != `{"x": 10, "y": 20}` { + t.Errorf("Expected Arguments '{\"x\": 10, \"y\": 20}', got %s", chatToolCall.Function.Arguments) + } +} + +// TestResponsesToolMessageToChatAssistantMessageToolCall_Nil tests nil handling +func TestResponsesToolMessageToChatAssistantMessageToolCall_Nil(t *testing.T) { + responsesToolMsg := &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call-123"), + Name: schemas.Ptr("calculate"), + Arguments: nil, // Test nil Arguments case + } + + chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() + if chatToolCall == nil { + t.Fatal("Expected non-nil ChatAssistantMessageToolCall") + } + + // Assert that nil Arguments produces a valid empty JSON object + if chatToolCall.Function.Arguments != "{}" { + t.Errorf("Expected Arguments '{}' for nil input, got %q", chatToolCall.Function.Arguments) + } + + // Verify it's valid JSON by attempting to unmarshal + var args map[string]interface{} + if err := json.Unmarshal([]byte(chatToolCall.Function.Arguments), &args); err != nil { + t.Errorf("Expected valid JSON, but unmarshaling failed: %v", err) + } +} + +// TestChatMessageToResponsesToolMessage tests conversion of Chat tool result to Responses tool message +func TestChatMessageToResponsesToolMessage(t *testing.T) { + // Test with valid chat tool message + chatMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call-123"), + }, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Result: 30"), + }, + } + + responsesMsg := chatMsg.ToResponsesToolMessage() + + if responsesMsg == nil { + t.Fatal("Expected non-nil ResponsesMessage") + } + + if responsesMsg.Type == nil || *responsesMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput { + t.Errorf("Expected Type 'function_call_output', got %v", responsesMsg.Type) + } + + if responsesMsg.ResponsesToolMessage == nil { + t.Fatal("Expected non-nil ResponsesToolMessage") + } + + if responsesMsg.ResponsesToolMessage.CallID == nil || *responsesMsg.ResponsesToolMessage.CallID != "call-123" { + t.Errorf("Expected CallID 'call-123', got %v", responsesMsg.ResponsesToolMessage.CallID) + } + + if responsesMsg.ResponsesToolMessage.Output == nil { + t.Fatal("Expected non-nil Output") + } + + if responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil { + t.Fatal("Expected non-nil ResponsesToolCallOutputStr") + } + + if *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "Result: 30" { + t.Errorf("Expected Output 'Result: 30', got %s", *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr) + } +} + +// TestChatMessageToResponsesToolMessage_Nil tests nil handling +func TestChatMessageToResponsesToolMessage_Nil(t *testing.T) { + var chatMsg *schemas.ChatMessage + + responsesMsg := chatMsg.ToResponsesToolMessage() + + if responsesMsg != nil { + t.Errorf("Expected nil for nil input, got %v", responsesMsg) + } +} + +// TestChatMessageToResponsesToolMessage_NoToolMessage tests with non-tool message +func TestChatMessageToResponsesToolMessage_NoToolMessage(t *testing.T) { + // Chat message without ChatToolMessage + chatMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + } + + responsesMsg := chatMsg.ToResponsesToolMessage() + + if responsesMsg != nil { + t.Errorf("Expected nil for non-tool message, got %v", responsesMsg) + } +} + +// ============================================================================ +// RESPONSES API TOOL CONVERSION TESTS (Phase 3) +// ============================================================================ + +// TestExecuteAgentForResponsesRequest_ConversionRoundTrip tests that tool calls survive format conversion +// This is a unit test of the conversion logic only, not full agent execution +func TestExecuteAgentForResponsesRequest_ConversionRoundTrip(t *testing.T) { + // Create a tool message in Responses format + responsesToolMsg := &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call-456"), + Name: schemas.Ptr("readToolFile"), + Arguments: schemas.Ptr("{\"file\": \"test.txt\"}"), + } + + // Step 1: Convert Responses format to Chat format + chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() + + if chatToolCall == nil { + t.Fatal("Failed to convert Responses to Chat format") + } + + if *chatToolCall.ID != "call-456" { + t.Errorf("ID lost in conversion: expected 'call-456', got %s", *chatToolCall.ID) + } + + if *chatToolCall.Function.Name != "readToolFile" { + t.Errorf("Name lost in conversion: expected 'readToolFile', got %s", *chatToolCall.Function.Name) + } + + if chatToolCall.Function.Arguments != "{\"file\": \"test.txt\"}" { + t.Errorf("Arguments lost in conversion: expected '%s', got %s", + "{\"file\": \"test.txt\"}", chatToolCall.Function.Arguments) + } + + // Step 2: Simulate tool execution by creating a result message + chatResultMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: chatToolCall.ID, + }, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("File contents here"), + }, + } + + // Step 3: Convert tool result back to Responses format + responsesResultMsg := chatResultMsg.ToResponsesToolMessage() + + if responsesResultMsg == nil { + t.Fatal("Failed to convert Chat result to Responses format") + } + + if responsesResultMsg.ResponsesToolMessage.CallID == nil { + t.Error("CallID lost in round-trip conversion") + } else if *responsesResultMsg.ResponsesToolMessage.CallID != "call-456" { + t.Errorf("CallID changed in round-trip: expected 'call-456', got %s", *responsesResultMsg.ResponsesToolMessage.CallID) + } + + // Verify output is preserved + if responsesResultMsg.ResponsesToolMessage.Output == nil { + t.Error("Output lost in conversion") + } else if responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil { + t.Error("Output content lost in conversion") + } else if *responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "File contents here" { + t.Errorf("Output content changed: expected 'File contents here', got %s", + *responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr) + } + + // Verify message type is correct + if responsesResultMsg.Type == nil || *responsesResultMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput { + t.Errorf("Expected message type 'function_call_output', got %v", responsesResultMsg.Type) + } +} + +// TestExecuteAgentForResponsesRequest_OutputStructured tests conversion with structured output blocks +func TestExecuteAgentForResponsesRequest_OutputStructured(t *testing.T) { + chatResultMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call-789"), + }, + Content: &schemas.ChatMessageContent{ + ContentBlocks: []schemas.ChatContentBlock{ + { + Type: schemas.ChatContentBlockTypeText, + Text: schemas.Ptr("Block 1"), + }, + { + Type: schemas.ChatContentBlockTypeText, + Text: schemas.Ptr("Block 2"), + }, + }, + }, + } + + responsesMsg := chatResultMsg.ToResponsesToolMessage() + + if responsesMsg == nil { + t.Fatal("Expected non-nil ResponsesMessage for structured output") + } + + if responsesMsg.ResponsesToolMessage.Output == nil { + t.Fatal("Expected non-nil Output for structured content") + } + + if responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks == nil { + t.Error("Expected output blocks for structured content") + } else if len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) != 2 { + t.Errorf("Expected 2 output blocks, got %d", len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks)) + } +} diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go new file mode 100644 index 000000000..fe7392abb --- /dev/null +++ b/core/mcp/clientmanager.go @@ -0,0 +1,700 @@ +package mcp + +import ( + "context" + "fmt" + "maps" + "os" + "strings" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/maximhq/bifrost/core/schemas" +) + +// GetClients returns all MCP clients managed by the manager. +// +// Returns: +// - []*schemas.MCPClientState: List of all MCP clients +func (m *MCPManager) GetClients() []schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + + clients := make([]schemas.MCPClientState, 0, len(m.clientMap)) + for _, client := range m.clientMap { + snapshot := *client + if client.ToolMap != nil { + snapshot.ToolMap = make(map[string]schemas.ChatTool, len(client.ToolMap)) + maps.Copy(snapshot.ToolMap, client.ToolMap) + } + clients = append(clients, snapshot) + } + + return clients +} + +// ReconnectClient attempts to reconnect an MCP client if it is disconnected. +// It validates that the client exists and then establishes a new connection using +// the client's existing configuration. +// +// Parameters: +// - id: ID of the client to reconnect +// +// Returns: +// - error: Any error that occurred during reconnection +func (m *MCPManager) ReconnectClient(id string) error { + m.mu.Lock() + client, ok := m.clientMap[id] + if !ok { + m.mu.Unlock() + return fmt.Errorf("client %s not found", id) + } + config := client.ExecutionConfig + m.mu.Unlock() + + // connectToMCPClient handles locking internally + err := m.connectToMCPClient(config) + if err != nil { + return fmt.Errorf("failed to connect to MCP client %s: %w", id, err) + } + + return nil +} + +// AddClient adds a new MCP client to the manager. +// It validates the client configuration and establishes a connection. +// If connection fails, the client entry is automatically cleaned up. +// +// Parameters: +// - config: MCP client configuration +// +// Returns: +// - error: Any error that occurred during client addition or connection +func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(&config); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Make a copy of the config to use after unlocking + configCopy := config + + m.mu.Lock() + + if _, ok := m.clientMap[config.ID]; ok { + m.mu.Unlock() + return fmt.Errorf("client %s already exists", config.Name) + } + + // Create placeholder entry + m.clientMap[config.ID] = &schemas.MCPClientState{ + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + } + + // Temporarily unlock for the connection attempt + // This is to avoid deadlocks when the connection attempt is made + m.mu.Unlock() + + // Connect using the copied config + if err := m.connectToMCPClient(configCopy); err != nil { + // Re-lock to clean up the failed entry + m.mu.Lock() + delete(m.clientMap, config.ID) + m.mu.Unlock() + return fmt.Errorf("failed to connect to MCP client %s: %w", config.Name, err) + } + + return nil +} + +// RemoveClient removes an MCP client from the manager. +// It handles cleanup for all transport types (HTTP, STDIO, SSE). +// +// Parameters: +// - id: ID of the client to remove +func (m *MCPManager) RemoveClient(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.removeClientUnsafe(id) +} + +// removeClientUnsafe removes an MCP client from the manager without acquiring locks. +// This is an internal method that should only be called when the caller already holds +// the appropriate lock. It handles cleanup for all transport types including cancellation +// of SSE contexts and closing of transport connections. +// +// Parameters: +// - id: ID of the client to remove +// +// Returns: +// - error: Any error that occurred during client removal +func (m *MCPManager) removeClientUnsafe(id string) error { + client, ok := m.clientMap[id] + if !ok { + return fmt.Errorf("client %s not found", id) + } + + logger.Info(fmt.Sprintf("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)) + + // Stop health monitoring for this client + m.healthMonitorManager.StopMonitoring(id) + + // Cancel SSE context if present (required for proper SSE cleanup) + if client.CancelFunc != nil { + client.CancelFunc() + client.CancelFunc = nil + } + + // Close the client transport connection + // This handles cleanup for all transport types (HTTP, STDIO, SSE) + if client.Conn != nil { + if err := client.Conn.Close(); err != nil { + logger.Error("%s Failed to close MCP server '%s': %v", MCPLogPrefix, client.ExecutionConfig.Name, err) + } + client.Conn = nil + } + + // Clear client tool map + client.ToolMap = make(map[string]schemas.ChatTool) + + delete(m.clientMap, id) + return nil +} + +// EditClient updates an existing MCP client's configuration and refreshes its tool list. +// It updates the client's execution config with new settings and retrieves updated tools +// from the MCP server if the client is connected. +// This method does not refresh the client's tool list. +// To refresh the client's tool list, use the ReconnectClient method. +// +// Parameters: +// - id: ID of the client to edit +// - updatedConfig: Updated client configuration with new settings +// +// Returns: +// - error: Any error that occurred during client update or tool retrieval +func (m *MCPManager) EditClient(id string, updatedConfig schemas.MCPClientConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, ok := m.clientMap[id] + if !ok { + return fmt.Errorf("client %s not found", id) + } + + if err := validateMCPClientName(updatedConfig.Name); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Update the client's execution config with new tool filters + config := client.ExecutionConfig + config.Name = updatedConfig.Name + config.IsCodeModeClient = updatedConfig.IsCodeModeClient + config.Headers = updatedConfig.Headers + config.ToolsToExecute = updatedConfig.ToolsToExecute + config.ToolsToAutoExecute = updatedConfig.ToolsToAutoExecute + + // Store the updated config + client.ExecutionConfig = config + return nil +} + +// registerTool registers a typed tool handler with the local MCP server. +// This is a convenience function that handles the conversion between typed Go +// handlers and the MCP protocol. +// +// Type Parameters: +// - T: The expected argument type for the tool (must be JSON-deserializable) +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Typed function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error { + // Ensure local server is set up + if err := m.setupLocalHost(); err != nil { + return fmt.Errorf("failed to setup local host: %w", err) + } + + // Validate tool name + if strings.TrimSpace(name) == "" { + return fmt.Errorf("tool name is required") + } + if strings.Contains(name, "-") { + return fmt.Errorf("tool name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("tool name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("tool name cannot start with a number") + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Verify internal client exists + internalClient, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + + // Check if tool name already exists to prevent silent overwrites + if _, exists := internalClient.ToolMap[name]; exists { + return fmt.Errorf("tool '%s' is already registered", name) + } + + logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) + + // Create MCP handler wrapper that converts between typed and MCP interfaces + mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments from the request using the request's methods + args := request.GetArguments() + result, err := toolFunction(args) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil + } + return mcp.NewToolResultText(result), nil + } + + // Register the tool with the local MCP server using AddTool + if m.server != nil { + tool := mcp.NewTool(name, mcp.WithDescription(description)) + m.server.AddTool(tool, mcpHandler) + } + + // Store tool definition for Bifrost integration + internalClient.ToolMap[name] = toolSchema + + return nil +} + +// ============================================================================ +// CONNECTION HELPER METHODS +// ============================================================================ + +// connectToMCPClient establishes a connection to an external MCP server and +// registers its available tools with the manager. +func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { + // First lock: Initialize or validate client entry + m.mu.Lock() + + // Initialize or validate client entry + if existingClient, exists := m.clientMap[config.ID]; exists { + // Client entry exists from config, check for existing connection, if it does then close + if existingClient.CancelFunc != nil { + existingClient.CancelFunc() + existingClient.CancelFunc = nil + } + if existingClient.Conn != nil { + existingClient.Conn.Close() + } + // Update connection type for this connection attempt + existingClient.ConnectionInfo.Type = config.ConnectionType + } + // Create new client entry with configuration + m.clientMap[config.ID] = &schemas.MCPClientState{ + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ConnectionInfo: schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, + } + m.mu.Unlock() + + // Heavy operations performed outside lock + var externalClient *client.Client + var connectionInfo schemas.MCPClientConnectionInfo + var err error + + // Create appropriate transport based on connection type + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + externalClient, connectionInfo, err = m.createHTTPConnection(config) + case schemas.MCPConnectionTypeSTDIO: + externalClient, connectionInfo, err = m.createSTDIOConnection(config) + case schemas.MCPConnectionTypeSSE: + externalClient, connectionInfo, err = m.createSSEConnection(config) + case schemas.MCPConnectionTypeInProcess: + externalClient, connectionInfo, err = m.createInProcessConnection(config) + default: + return fmt.Errorf("unknown connection type: %s", config.ConnectionType) + } + + if err != nil { + return fmt.Errorf("failed to create connection: %w", err) + } + + // Initialize the external client with timeout + // For SSE connections, we need a long-lived context, for others we can use timeout + var ctx context.Context + var cancel context.CancelFunc + + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + // SSE connections need a long-lived context for the persistent stream + ctx, cancel = context.WithCancel(m.ctx) + // Don't defer cancel here - SSE needs the context to remain active + } else { + // Other connection types can use timeout context + ctx, cancel = context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + } + + // Start the transport first (required for STDIO and SSE clients) + if err := externalClient.Start(ctx); err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) + } + + // Create proper initialize request for external client + extInitRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s", config.Name), + Version: "1.0.0", + }, + }, + } + + _, err = externalClient.Initialize(ctx, extInitRequest) + if err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) + } + + // Retrieve tools from the external server (this also requires network I/O) + tools, err := retrieveExternalTools(ctx, externalClient, config.Name) + if err != nil { + logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) + // Continue with connection even if tool retrieval fails + tools = make(map[string]schemas.ChatTool) + } + + // Second lock: Update client with final connection details and tools + m.mu.Lock() + defer m.mu.Unlock() + + // Verify client still exists (could have been cleaned up during heavy operations) + if client, exists := m.clientMap[config.ID]; exists { + // Store the external client connection and details + client.Conn = externalClient + client.ConnectionInfo = connectionInfo + client.State = schemas.MCPConnectionStateConnected + + // Store cancel function for SSE connections to enable proper cleanup + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + client.CancelFunc = cancel + } + + // Store discovered tools + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + + logger.Info(fmt.Sprintf("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name)) + } else { + // Clean up resources before returning error: client was removed during connection setup + // Cancel SSE context if it was created + if config.ConnectionType == schemas.MCPConnectionTypeSSE && cancel != nil { + cancel() + } + // Close external client connection to prevent transport/goroutine leaks + if externalClient != nil { + if err := externalClient.Close(); err != nil { + logger.Warn(fmt.Sprintf("%s Failed to close external client during cleanup: %v", MCPLogPrefix, err)) + } + } + return fmt.Errorf("client %s was removed during connection setup", config.Name) + } + + // Register OnConnectionLost hook for SSE connections to detect idle timeouts + if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil { + externalClient.OnConnectionLost(func(err error) { + logger.Warn(fmt.Sprintf("%s SSE connection lost for MCP server '%s': %v", MCPLogPrefix, config.Name, err)) + // Update state to disconnected + m.mu.Lock() + if client, exists := m.clientMap[config.ID]; exists { + client.State = schemas.MCPConnectionStateDisconnected + } + m.mu.Unlock() + }) + } + + // Start health monitoring for the client + monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval) + m.healthMonitorManager.StartMonitoring(monitor) + + return nil +} + +// createHTTPConnection creates an HTTP-based MCP client connection without holding locks. +func (m *MCPManager) createHTTPConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, + } + + // Create StreamableHTTP transport + httpTransport, err := transport.NewStreamableHTTP(*config.ConnectionString, transport.WithHTTPHeaders(config.Headers)) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + client := client.NewClient(httpTransport) + + return client, connectionInfo, nil +} + +// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. +func (m *MCPManager) createSTDIOConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.StdioConfig == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + } + + // Prepare STDIO command info for display + cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " ")) + + // Check if environment variables are set + for _, env := range config.StdioConfig.Envs { + if os.Getenv(env) == "" { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + } + } + + // Create STDIO transport + stdioTransport := transport.NewStdio( + config.StdioConfig.Command, + config.StdioConfig.Envs, + config.StdioConfig.Args..., + ) + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + StdioCommandString: &cmdString, + } + + client := client.NewClient(stdioTransport) + + // Return nil for cmd since mark3labs/mcp-go manages the process internally + return client, connectionInfo, nil +} + +// createSSEConnection creates a SSE-based MCP client connection without holding locks. +func (m *MCPManager) createSSEConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, // Reuse HTTPConnectionURL field for SSE URL display + } + + // Create SSE transport + sseTransport, err := transport.NewSSE(*config.ConnectionString, transport.WithHeaders(config.Headers)) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + } + + client := client.NewClient(sseTransport) + + return client, connectionInfo, nil +} + +// createInProcessConnection creates an in-process MCP client connection without holding locks. +// This allows direct connection to an MCP server running in the same process, providing +// the lowest latency and highest performance for tool execution. +func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.InProcessServer == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") + } + + // Create in-process client directly connected to the provided server + inProcessClient, err := client.NewInProcessClient(config.InProcessServer) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + } + + return inProcessClient, connectionInfo, nil +} + +// ============================================================================ +// LOCAL MCP SERVER AND CLIENT MANAGEMENT +// ============================================================================ + +// setupLocalHost initializes the local MCP server and client if not already running. +// This creates a STDIO-based server for local tool hosting and a corresponding client. +// This is called automatically when tools are registered or when the server is needed. +// +// Returns: +// - error: Any setup error +func (m *MCPManager) setupLocalHost() error { + // First check: fast path if already initialized + m.mu.Lock() + if m.server != nil && m.serverRunning { + m.mu.Unlock() + return nil + } + m.mu.Unlock() + + // Create server and client into local variables (outside lock to avoid + // holding lock during object creation, even though it's lightweight) + server, err := m.createLocalMCPServer() + if err != nil { + return fmt.Errorf("failed to create local MCP server: %w", err) + } + + client, err := m.createLocalMCPClient() + if err != nil { + return fmt.Errorf("failed to create local MCP client: %w", err) + } + + // Second check and assignment: hold lock for atomic check-and-set + m.mu.Lock() + // Double-check: another goroutine might have initialized while we were creating + if m.server != nil && m.serverRunning { + m.mu.Unlock() + return nil + } + + // Assign server and client atomically while holding the lock + m.server = server + m.clientMap[BifrostMCPClientKey] = client + m.mu.Unlock() + + // Start the server and initialize client connection + // (startLocalMCPServer already locks internally) + return m.startLocalMCPServer() +} + +// createLocalMCPServer creates a new local MCP server instance with STDIO transport. +// This server will host tools registered via RegisterTool function. +// +// Returns: +// - *server.MCPServer: Configured MCP server instance +// - error: Any creation error +func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) { + // Create MCP server + mcpServer := server.NewMCPServer( + "Bifrost-MCP-Server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + return mcpServer, nil +} + +// createLocalMCPClient creates a placeholder client entry for the local MCP server. +// The actual in-process client connection will be established in startLocalMCPServer. +// +// Returns: +// - *schemas.MCPClientState: Placeholder client for local server +// - error: Any creation error +func (m *MCPManager) createLocalMCPClient() (*schemas.MCPClientState, error) { + // Don't create the actual client connection here - it will be created + // after the server is ready using NewInProcessClient + return &schemas.MCPClientState{ + ExecutionConfig: schemas.MCPClientConfig{ + ID: BifrostMCPClientKey, + Name: BifrostMCPClientName, + ToolsToExecute: []string{"*"}, // Allow all tools for internal client + }, + ToolMap: make(map[string]schemas.ChatTool), + ConnectionInfo: schemas.MCPClientConnectionInfo{ + Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport + }, + }, nil +} + +// startLocalMCPServer creates an in-process connection between the local server and client. +// +// Returns: +// - error: Any startup error +func (m *MCPManager) startLocalMCPServer() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + if m.server == nil { + return fmt.Errorf("server not initialized") + } + + // Create in-process client directly connected to the server + inProcessClient, err := client.NewInProcessClient(m.server) + if err != nil { + return fmt.Errorf("failed to create in-process MCP client: %w", err) + } + + // Update the client connection + clientEntry, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + clientEntry.Conn = inProcessClient + + // Initialize the in-process client + ctx, cancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Create proper initialize request with correct structure + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: BifrostMCPClientName, + Version: BifrostMCPVersion, + }, + }, + } + + _, err = inProcessClient.Initialize(ctx, initRequest) + if err != nil { + return fmt.Errorf("failed to initialize MCP client: %w", err) + } + + // Mark server as running + m.serverRunning = true + + return nil +} diff --git a/core/mcp/codemode_executecode.go b/core/mcp/codemode_executecode.go new file mode 100644 index 000000000..201933920 --- /dev/null +++ b/core/mcp/codemode_executecode.go @@ -0,0 +1,1035 @@ +package mcp + +import ( + "context" + "fmt" + "regexp" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/clarkmcc/go-typescript" + "github.com/dop251/goja" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// toolBinding represents a tool binding for the VM +type toolBinding struct { + toolName string + clientName string +} + +// toolCallInfo represents a tool call extracted from code +type toolCallInfo struct { + serverName string + toolName string +} + +// ExecutionResult represents the result of code execution +type ExecutionResult struct { + Result interface{} `json:"result"` + Logs []string `json:"logs"` + Errors *ExecutionError `json:"errors,omitempty"` + Environment ExecutionEnvironment `json:"environment"` +} + +type ExecutionErrorType string + +const ( + ExecutionErrorTypeCompile ExecutionErrorType = "compile" + ExecutionErrorTypeTypescript ExecutionErrorType = "typescript" + ExecutionErrorTypeRuntime ExecutionErrorType = "runtime" +) + +// ExecutionError represents an error during code execution +type ExecutionError struct { + Kind ExecutionErrorType `json:"kind"` // "compile", "typescript", or "runtime" + Message string `json:"message"` + Hints []string `json:"hints"` +} + +// ExecutionEnvironment contains information about the execution environment +type ExecutionEnvironment struct { + ServerKeys []string `json:"serverKeys"` + ImportsStripped bool `json:"importsStripped"` + StrippedLines []int `json:"strippedLines"` + TypeScriptUsed bool `json:"typescriptUsed"` +} + +const ( + CodeModeLogPrefix = "[CODE MODE]" +) + +// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode. +// This tool allows executing TypeScript code in a sandboxed VM with access to MCP server tools. +// +// Returns: +// - schemas.ChatTool: The tool definition for executing tool code +func (m *ToolsManager) createExecuteToolCodeTool() schemas.ChatTool { + executeToolCodeProps := schemas.OrderedMap{ + "code": map[string]interface{}{ + "type": "string", + "description": "TypeScript code to execute. The code will be transpiled to JavaScript and validated before execution. Import/export statements will be stripped. You can use async/await syntax for async operations. For simple use cases, directly return results. Check keys and value types only for debugging. Do not print entire outputs in console logs - only print structure (keys, types) when debugging. ALWAYS retry if code fails. Example (simple): const result = await serverName.toolName({arg: 'value'}); return result; Example (debugging): const result = await serverName.toolName({arg: 'value'}); const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); console.log('Structure:', getStruct(result)); return result;", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeExecuteToolCode, + Description: schemas.Ptr( + "Executes TypeScript code inside a sandboxed goja-based VM with access to all connected MCP servers' tools. " + + "TypeScript code is automatically transpiled to JavaScript and validated before execution, providing type checking and validation. " + + "All connected servers are exposed as global objects named after their configuration keys, and each server " + + "provides async (Promise-returning) functions for every tool available on that server. The canonical usage " + + "pattern is: const result = await .({ ...args }); Both and " + + "should be discovered using listToolFiles and readToolFile. " + + + "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + + "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + + "to execute your code. Check listToolFiles whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + + "LOGGING GUIDELINES: For simple use cases, you can directly return results without logging. Check for keys and value types only " + + "for debugging purposes when you need to understand the response structure. Do not print the entire output in console logs. " + + "When debugging, use console logs to print just the output structure to understand its type. For nested objects, use a recursive helper to show types at all levels. " + + "For example: const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); " + + "console.log('Structure:', getStruct(result)); Only print the entire data if absolutely necessary for debugging. " + + "This helps understand the response structure without cluttering the output with full object contents. " + + + "RETRY POLICY: ALWAYS retry if a code block fails. If execution produces an error or unexpected result, analyze the error, " + + "adjust your code accordingly for better results or debugging, and retry the execution. Do not give up after a single failure — iterate and improve your code until it succeeds. " + + + "The environment is intentionally minimal and has several constraints: " + + "• ES modules are not supported — any leading import/export statements are automatically stripped and imported symbols will not exist. " + + "• Browser and Node APIs such as fetch, XMLHttpRequest, axios, require, setTimeout, setInterval, window, and document do not exist. " + + "• async/await syntax is supported and automatically transpiled to Promise chains compatible with goja. " + + "• Using undefined server names or tool names will result in reference or function errors. " + + "• The VM does not emulate a browser or Node.js environment — no DOM, timers, modules, or network APIs are available. " + + "• Only ES5.1+ features supported by goja are guaranteed to work. " + + "• TypeScript type checking occurs during transpilation — type errors will prevent execution. " + + + "If you want a value returned from the code, write a top-level 'return '; otherwise the return value will be null. " + + "Console output (log, error, warn, info) is captured and returned. " + + "Long-running or blocked operations are interrupted via execution timeout. " + + "This tool is designed specifically for orchestrating MCP tool calls and lightweight TypeScript computation.", + ), + + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &executeToolCodeProps, + Required: []string{"code"}, + }, + }, + } +} + +// handleExecuteToolCode handles the executeToolCode tool call. +// It parses the code argument, executes it in a sandboxed VM, and formats the response +// with execution results, logs, errors, and environment information. +// +// Parameters: +// - ctx: Context for code execution +// - toolCall: The tool call request containing the TypeScript code to execute +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing execution results +// - error: Any error that occurred during processing +func (m *ToolsManager) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + toolName := "unknown" + if toolCall.Function.Name != nil { + toolName = *toolCall.Function.Name + } + logger.Debug(fmt.Sprintf("%s Handling executeToolCode tool call: %s", CodeModeLogPrefix, toolName)) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + return nil, fmt.Errorf("code parameter is required and must be a non-empty string") + } + + logger.Debug(fmt.Sprintf("%s Starting code execution", CodeModeLogPrefix)) + result := m.executeCode(ctx, code) + logger.Debug(fmt.Sprintf("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))) + + // Format response text + var responseText string + var executionSuccess bool = true // Track if execution was successful (has data) + if result.Errors != nil { + logger.Debug(fmt.Sprintf("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))) + logsText := "" + if len(result.Logs) > 0 { + logsText = fmt.Sprintf("\n\nConsole/Log Output:\n%s\n", + strings.Join(result.Logs, "\n")) + } + errorKindLabel := result.Errors.Kind + + responseText = fmt.Sprintf( + "Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + errorKindLabel, + result.Errors.Message, + strings.Join(result.Errors.Hints, "\n"), + logsText, + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], + ) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + logger.Debug(fmt.Sprintf("%s Error response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) + } else { + // Success case - check if execution produced any data + hasLogs := len(result.Logs) > 0 + hasResult := result.Result != nil + logger.Debug(fmt.Sprintf("%s Formatting success response. Has logs: %v, Has result: %v", CodeModeLogPrefix, hasLogs, hasResult)) + + // If execution completed but produced no data (no logs, no return value), treat as failure + if !hasLogs && !hasResult { + executionSuccess = false + logger.Debug(fmt.Sprintf("%s Execution completed with no data (no logs, no result), marking as failure", CodeModeLogPrefix)) + hints := []string{ + "Add console.log() statements throughout your code to debug and see what's happening at each step", + "Ensure your code has a top-level return statement if you want to return a value", + "Check that your tool calls are actually executing and returning data", + "Verify that async operations (like await) are properly handled", + } + responseText = fmt.Sprintf( + "Execution completed but produced no data:\n\n"+ + "The code executed without errors but returned no output (no console logs and no return value).\n\n"+ + "Hints:\n%s\n\n"+ + "Environment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + strings.Join(hints, "\n"), + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], + ) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + logger.Debug(fmt.Sprintf("%s No-data failure response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) + } else { + // Normal success case with data + if hasLogs { + responseText = fmt.Sprintf("Console output:\n%s\n\nExecution completed successfully.", + strings.Join(result.Logs, "\n")) + } else { + responseText = "Execution completed successfully." + } + if hasResult { + resultJSON, err := sonic.MarshalIndent(result.Result, "", " ") + if err == nil { + responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON)) + logger.Debug(fmt.Sprintf("%s Added return value to response (JSON length: %d chars)", CodeModeLogPrefix, len(resultJSON))) + } else { + logger.Debug(fmt.Sprintf("%s Failed to marshal result to JSON: %v", CodeModeLogPrefix, err)) + } + } + + // Add environment information for successful executions + responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped]) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + responseText += "\nNote: Browser APIs like fetch, setTimeout are not available. Use MCP tools for external interactions." + logger.Debug(fmt.Sprintf("%s Success response formatted. Response length: %d chars, Server keys: %v", CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)) + } + } + + logger.Debug(fmt.Sprintf("%s Returning tool response message. Execution success: %v", CodeModeLogPrefix, executionSuccess)) + return createToolResponseMessage(toolCall, responseText), nil +} + +// executeCode executes TypeScript code in a sandboxed VM with MCP tool bindings. +// It handles code preprocessing (stripping imports/exports), TypeScript transpilation, +// VM setup with tool bindings, and promise-based async execution with timeout handling. +// +// Parameters: +// - ctx: Context for code execution (used for timeout and tool access) +// - code: TypeScript code string to execute +// +// Returns: +// - ExecutionResult: Result containing execution output, logs, errors, and environment info +func (m *ToolsManager) executeCode(ctx context.Context, code string) ExecutionResult { + logs := []string{} + strippedLines := []int{} + + logger.Debug(fmt.Sprintf("%s Starting TypeScript code execution", CodeModeLogPrefix)) + + // Step 1: Convert literal \n escape sequences to actual newlines first + // This ensures multiline code and import/export stripping work correctly + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + + // Step 2: Strip import/export statements + cleanedCode, strippedLineNumbers := stripImportsAndExports(codeWithNewlines) + strippedLines = append(strippedLines, strippedLineNumbers...) + if len(strippedLineNumbers) > 0 { + logger.Debug(fmt.Sprintf("%s Stripped %d import/export lines", CodeModeLogPrefix, len(strippedLineNumbers))) + } + + // Step 3: Handle empty code after stripping (in case stripping made it empty) + trimmedCode := strings.TrimSpace(cleanedCode) + if trimmedCode == "" { + // Empty code should return null - return early without VM execution + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: []string{}, // Will be populated below if needed, but empty code doesn't need tools + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + // Step 4: Wrap code in async function for proper await transpilation + // TypeScript needs an async function context to properly transpile await expressions + // Check if code is already an async IIFE - if so, await it + trimmedLower := strings.ToLower(strings.TrimSpace(trimmedCode)) + isAsyncIIFE := strings.HasPrefix(trimmedLower, "(async") && strings.Contains(trimmedCode, ")()") + + var codeToTranspile string + if isAsyncIIFE { + // Code is already an async IIFE - await it to get the result + codeToTranspile = fmt.Sprintf("async function __execute__() {\nreturn await %s\n}", trimmedCode) + } else { + // Regular code - wrap in async function + codeToTranspile = fmt.Sprintf("async function __execute__() {\n%s\n}", trimmedCode) + } + + // Step 5: Transpile TypeScript to JavaScript with validation + // Configure TypeScript compiler to transpile async/await to Promise chains (ES5 compatible) + logger.Debug(fmt.Sprintf("%s Transpiling TypeScript code", CodeModeLogPrefix)) + compileOptions := map[string]interface{}{ + "target": "ES5", // Target ES5 for goja compatibility + "module": "None", // No module system + "lib": []string{}, // No lib (minimal environment) + "downlevelIteration": true, // Support async/await transpilation + } + jsCode, transpileErr := typescript.TranspileString(codeToTranspile, typescript.WithCompileOptions(compileOptions)) + if transpileErr != nil { + logger.Debug(fmt.Sprintf("%s TypeScript transpilation failed: %v", CodeModeLogPrefix, transpileErr)) + // Build bindings to get server keys for error hints + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + for clientName := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + serverKeys = append(serverKeys, clientName) + } + + errorMessage := transpileErr.Error() + hints := generateTypeScriptErrorHints(errorMessage, serverKeys) + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: ExecutionErrorTypeTypescript, + Message: fmt.Sprintf("TypeScript compilation error: %s", errorMessage), + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + logger.Debug(fmt.Sprintf("%s TypeScript transpiled successfully", CodeModeLogPrefix)) + + // Step 5: Create timeout context early so goroutines can use it + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + // Step 6: Build bindings for all connected servers + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + bindings := make(map[string]map[string]toolBinding) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + serverKeys = append(serverKeys, clientName) + + toolFunctions := make(map[string]toolBinding) + + // Create a function for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + // Parse tool name for property name compatibility (used as property name in the runtime) + parsedToolName := parseToolName(originalToolName) + + // Store tool binding + toolFunctions[parsedToolName] = toolBinding{ + toolName: originalToolName, + clientName: clientName, + } + } + + bindings[clientName] = toolFunctions + } + + if len(serverKeys) > 0 { + logger.Debug(fmt.Sprintf("%s Bound %d servers with tools", CodeModeLogPrefix, len(serverKeys))) + } + + // Step 7: Wrap transpiled code to execute the async function and return its result + // The transpiled code contains an async function __execute__() that we need to call + // Trim trailing newlines to avoid issues when wrapping + codeToWrap := strings.TrimRight(jsCode, "\n\r") + // Wrap in IIFE that calls the transpiled async function and returns the promise + wrappedCode := fmt.Sprintf("(function() {\n%s\nreturn __execute__();\n})()", codeToWrap) + + // Step 8: Create goja runtime + vm := goja.New() + + // Step 9: Set up thread-safe logging + appendLog := func(msg string) { + m.logMu.Lock() + defer m.logMu.Unlock() + logs = append(logs, msg) + } + + // Step 10: Set up console + consoleObj := vm.NewObject() + consoleObj.Set("log", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(message) + }) + consoleObj.Set("error", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[ERROR] %s", message)) + }) + consoleObj.Set("warn", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[WARN] %s", message)) + }) + consoleObj.Set("info", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[INFO] %s", message)) + }) + vm.Set("console", consoleObj) + + // Step 11: Set up server bindings + for serverKey, tools := range bindings { + serverObj := vm.NewObject() + for toolName, binding := range tools { + // Capture variables for closure + toolNameFinal := binding.toolName + clientNameFinal := binding.clientName + + serverObj.Set(toolName, func(call goja.FunctionCall) goja.Value { + args := call.Argument(0).Export() + + // Convert args to map[string]interface{} + argsMap, ok := args.(map[string]interface{}) + if !ok { + logger.Debug(fmt.Sprintf("%s Invalid args type for %s.%s: expected object, got %T", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, args)) + // Return rejected promise for invalid args + promise, _, reject := vm.NewPromise() + err := fmt.Errorf("expected object argument, got %T", args) + reject(vm.ToValue(err)) + return vm.ToValue(promise) + } + + // Create promise on VM goroutine (thread-safe) + promise, resolve, reject := vm.NewPromise() + + // Define result struct for channel communication + type toolResult struct { + result interface{} + err error + } + + // Create buffered channel for worker communication + resultChan := make(chan toolResult, 1) + + // Call tool asynchronously with timeout context and panic recovery + // Worker goroutine - NO VM calls allowed here + go func() { + defer func() { + if r := recover(); r != nil { + logger.Debug(fmt.Sprintf("%s Panic in tool call goroutine for %s.%s: %v", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, r)) + // Send panic as error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("tool call panic: %v", r)}: + case <-timeoutCtx.Done(): + // Context cancelled, ignore + } + } + }() + + // Check if context is already cancelled before starting + select { + case <-timeoutCtx.Done(): + // Send timeout error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: + case <-timeoutCtx.Done(): + // Already cancelled, ignore + } + return + default: + } + + result, err := m.callMCPTool(timeoutCtx, clientNameFinal, toolNameFinal, argsMap, appendLog) + + // Check if context was cancelled during execution + select { + case <-timeoutCtx.Done(): + // Send timeout error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: + case <-timeoutCtx.Done(): + // Already cancelled, ignore + } + return + default: + } + + // Send result through channel (no VM calls in worker) + select { + case resultChan <- toolResult{result, err}: + case <-timeoutCtx.Done(): + // Context cancelled, ignore + } + }() + + // Process result synchronously on VM goroutine to ensure thread safety + // This blocks the VM goroutine until the tool call completes, but ensures + // all VM operations (vm.ToValue, resolve, reject) happen on the correct thread + select { + case res := <-resultChan: + if res.err != nil { + logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, res.err)) + reject(vm.ToValue(res.err)) + } else { + resolve(vm.ToValue(res.result)) + } + case <-timeoutCtx.Done(): + reject(vm.ToValue(fmt.Errorf("execution timeout"))) + } + + return vm.ToValue(promise) + }) + } + vm.Set(serverKey, serverObj) + } + + // Step 12: Set up environment info + envObj := vm.NewObject() + envObj.Set("serverKeys", serverKeys) + envObj.Set("version", "1.0.0") + vm.Set("__MCP_ENV__", envObj) + + // Step 13: Execute code with timeout + + // Set up interrupt handler + interruptDone := make(chan struct{}) + go func() { + select { + case <-timeoutCtx.Done(): + logger.Debug(fmt.Sprintf("%s Execution timeout reached", CodeModeLogPrefix)) + vm.Interrupt("execution timeout") + case <-interruptDone: + } + }() + + var result interface{} + var executionErr error + + func() { + defer close(interruptDone) + val, err := vm.RunString(wrappedCode) + if err != nil { + logger.Debug(fmt.Sprintf("%s VM execution error: %v", CodeModeLogPrefix, err)) + executionErr = err + return + } + + // Check if the result is a promise by checking its type + // First check if val is nil or undefined (these can't be converted to objects) + if val == nil || val == goja.Undefined() { + result = nil + return + } + + // Try to convert to object to check if it's a promise + // Use recover to safely handle null values that can't be converted to objects + var valObj *goja.Object + func() { + defer func() { + if r := recover(); r != nil { + // Value is null or can't be converted to object, just export it + valObj = nil + } + }() + valObj = val.ToObject(vm) + }() + + if valObj != nil { + // Check if it has a 'then' method (Promise-like) + if then := valObj.Get("then"); then != nil && then != goja.Undefined() { + // It's a promise, we need to await it + // Use buffered channels to prevent blocking if handlers are called after timeout + resultChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + + // Set up promise handlers + thenFunc, ok := goja.AssertFunction(then) + if ok { + // Call then with resolve and reject handlers + _, err := thenFunc(val, + vm.ToValue(func(res goja.Value) { + select { + case resultChan <- res.Export(): + case <-timeoutCtx.Done(): + // Timeout already occurred, ignore result + } + }), + vm.ToValue(func(err goja.Value) { + var errMsg string + if err == nil || err == goja.Undefined() { + errMsg = "unknown error" + } else { + // Try to get error message from Error object + if errObj := err.ToObject(vm); errObj != nil { + if msg := errObj.Get("message"); msg != nil && msg != goja.Undefined() { + errMsg = msg.String() + } else if name := errObj.Get("name"); name != nil && name != goja.Undefined() { + errMsg = name.String() + } else { + errMsg = err.String() + } + } else { + // Fallback to string conversion + errMsg = err.String() + } + } + select { + case errChan <- fmt.Errorf("%s", errMsg): + case <-timeoutCtx.Done(): + // Timeout already occurred, ignore error + } + }), + ) + if err != nil { + executionErr = err + return + } + + // Wait for result or error with timeout + select { + case res := <-resultChan: + result = res + case err := <-errChan: + logger.Debug(fmt.Sprintf("%s Promise rejected: %v", CodeModeLogPrefix, err)) + executionErr = err + case <-timeoutCtx.Done(): + logger.Debug(fmt.Sprintf("%s Promise timeout while waiting for result", CodeModeLogPrefix)) + executionErr = fmt.Errorf("execution timeout") + } + } else { + result = val.Export() + } + } else { + result = val.Export() + } + } else { + // Not an object (or null/undefined), just export the value + result = val.Export() + } + }() + + if executionErr != nil { + errorMessage := executionErr.Error() + hints := generateErrorHints(errorMessage, serverKeys) + logger.Debug(fmt.Sprintf("%s Execution failed: %s", CodeModeLogPrefix, errorMessage)) + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: ExecutionErrorTypeRuntime, + Message: errorMessage, + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + logger.Debug(fmt.Sprintf("%s Execution completed successfully", CodeModeLogPrefix)) + return ExecutionResult{ + Result: result, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } +} + +// callMCPTool calls an MCP tool and returns the result. +// It locates the client by name, constructs the MCP tool call request, executes it +// with timeout handling, and parses the response as JSON or returns it as a string. +// +// Parameters: +// - ctx: Context for tool execution (used for timeout) +// - clientName: Name of the MCP client/server to call +// - toolName: Name of the tool to execute +// - args: Tool arguments as a map +// - appendLog: Function to append log messages during execution +// +// Returns: +// - interface{}: Parsed tool result (JSON object or string) +// - error: Any error that occurred during tool execution +func (m *ToolsManager) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { + // Get available tools per client + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + // Find the client by name + tools, exists := availableToolsPerClient[clientName] + if !exists || len(tools) == 0 { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Get client using a tool from this client + // Find the first tool with a valid Function to use for client lookup + var client *schemas.MCPClientState + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + client = m.clientManager.GetClientForTool(tool.Function.Name) + if client != nil { + break + } + } + } + + if client == nil { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name, not the prefixed version + originalToolName := stripClientPrefix(toolName, clientName) + + // Call the tool via MCP client + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: args, + }, + } + + // Create timeout context + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", CodeModeLogPrefix, clientName, toolName, callErr)) + appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr)) + return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, toolName, callErr) + } + + // Extract result + rawResult := extractTextFromMCPResponse(toolResponse, toolName) + + // Check if this is an error result (from NewToolResultError) + // Error results start with "Error: " prefix + if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { + errorMsg := after + logger.Debug(fmt.Sprintf("%s Tool returned error result: %s.%s - %s", CodeModeLogPrefix, clientName, toolName, errorMsg)) + appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg)) + return nil, fmt.Errorf("%s", errorMsg) + } + + // Try to parse as JSON, otherwise use as string + var finalResult interface{} + if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { + // Not JSON, use as string + finalResult = rawResult + } + + // Log the result + resultStr := formatResultForLog(finalResult) + appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, toolName, resultStr)) + + return finalResult, nil +} + +// HELPER FUNCTIONS + +// formatResultForLog formats a result value for logging purposes. +// It attempts to marshal to JSON for structured output, falling back to string representation. +// +// Parameters: +// - result: The result value to format +// +// Returns: +// - string: Formatted string representation of the result +func formatResultForLog(result interface{}) string { + var resultStr string + if result == nil { + resultStr = "null" + } else if resultBytes, err := sonic.Marshal(result); err == nil { + resultStr = string(resultBytes) + } else { + resultStr = fmt.Sprintf("%v", result) + } + return resultStr +} + +// formatConsoleArgs formats console arguments for logging. +// It formats each argument as JSON if possible, otherwise uses string representation. +// +// Parameters: +// - args: Array of console arguments to format +// +// Returns: +// - string: Formatted string with all arguments joined by spaces +func formatConsoleArgs(args []interface{}) string { + parts := make([]string, len(args)) + for i, arg := range args { + if argBytes, err := sonic.MarshalIndent(arg, "", " "); err == nil { + parts[i] = string(argBytes) + } else { + parts[i] = fmt.Sprintf("%v", arg) + } + } + return strings.Join(parts, " ") +} + +// stripImportsAndExports strips import and export statements from code. +// It removes lines that start with import or export keywords and returns +// the cleaned code along with 1-based line numbers of stripped lines. +// +// Parameters: +// - code: Source code string to process +// +// Returns: +// - string: Code with import/export statements removed +// - []int: 1-based line numbers of stripped lines +func stripImportsAndExports(code string) (string, []int) { + lines := strings.Split(code, "\n") + keptLines := []string{} + strippedLineNumbers := []int{} + + importExportRegex := regexp.MustCompile(`^\s*(import|export)\b`) + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip empty lines + if trimmed == "" { + keptLines = append(keptLines, line) + continue + } + + // Check if this is an import or export statement + isImportOrExport := importExportRegex.MatchString(line) + + if isImportOrExport { + strippedLineNumbers = append(strippedLineNumbers, i+1) // 1-based line numbers + continue // Skip import/export lines + } + + // Keep comment lines and all other non-import/export lines + keptLines = append(keptLines, line) + } + + return strings.Join(keptLines, "\n"), strippedLineNumbers +} + +// generateTypeScriptErrorHints generates helpful hints for TypeScript compilation errors. +// It analyzes the error message and provides context-specific guidance based on error patterns. +// +// Parameters: +// - errorMessage: The TypeScript compilation error message +// - serverKeys: List of available MCP server keys for context +// +// Returns: +// - []string: Array of helpful hint messages +func generateTypeScriptErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + // TypeScript-specific error patterns + if strings.Contains(errorMessage, "Cannot find name") || strings.Contains(errorMessage, "is not defined") { + hints = append(hints, "TypeScript compilation error: undefined variable or identifier.") + hints = append(hints, "Check that all variables are properly declared and typed.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Use server keys to access MCP tools: .(args)") + } + } else if strings.Contains(errorMessage, "Type") && (strings.Contains(errorMessage, "is not assignable") || strings.Contains(errorMessage, "does not exist")) { + hints = append(hints, "TypeScript type error detected.") + hints = append(hints, "Check that variable types match their usage.") + hints = append(hints, "Ensure function arguments match the expected types.") + } else if strings.Contains(errorMessage, "Expected") { + hints = append(hints, "TypeScript syntax error detected.") + hints = append(hints, "Check for missing parentheses, brackets, or semicolons.") + hints = append(hints, "Ensure all code blocks are properly closed.") + } else if strings.Contains(errorMessage, "async") || strings.Contains(errorMessage, "await") { + hints = append(hints, "async/await syntax should be supported. If you see this error, it may be a TypeScript compilation issue.") + hints = append(hints, "Ensure async functions are properly declared: async function myFunction() { ... }") + hints = append(hints, "Example: const result = await serverName.toolName({...});") + } else { + hints = append(hints, "TypeScript compilation error detected.") + hints = append(hints, "Review the error message above for specific details.") + hints = append(hints, "Ensure your TypeScript code follows valid syntax and type rules.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } + + return hints +} + +// generateErrorHints generates helpful hints based on runtime error messages. +// It analyzes common runtime error patterns (undefined variables, missing functions, etc.) +// and provides context-specific guidance including available server keys and usage examples. +// +// Parameters: +// - errorMessage: The runtime error message +// - serverKeys: List of available MCP server keys for context +// +// Returns: +// - []string: Array of helpful hint messages +func generateErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + if strings.Contains(errorMessage, "is not defined") { + re := regexp.MustCompile(`(\w+)\s+is not defined`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar := match[1] + + // Special handling for common browser/Node.js APIs + if undefinedVar == "fetch" { + hints = append(hints, "The 'fetch' API is not available in this runtime environment.") + hints = append(hints, "Instead of using fetch for HTTP requests, use the available MCP tools.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, fmt.Sprintf("Example: const result = await %s.({ url: 'https://example.com' });", serverKeys[0])) + } + hints = append(hints, "MCP tools handle HTTP requests, file operations, and other external interactions.") + return hints + } else if undefinedVar == "XMLHttpRequest" || undefinedVar == "axios" { + hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) + hints = append(hints, "Use MCP tools instead for HTTP requests and external API calls.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + return hints + } else if undefinedVar == "setTimeout" || undefinedVar == "setInterval" { + hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) + hints = append(hints, "This is a sandboxed environment focused on MCP tool interactions.") + hints = append(hints, "Use Promise chains with MCP tools instead of timing functions.") + return hints + } else if undefinedVar == "require" || undefinedVar == "import" { + hints = append(hints, "Module imports are not supported in this runtime environment.") + hints = append(hints, "Use the available MCP tools for external functionality.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + return hints + } + + // Generic undefined variable handling + hints = append(hints, fmt.Sprintf("Variable or identifier '%s' is not defined.", undefinedVar)) + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Use one of the available server keys as the object name: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Then access tools using: .(args)") + hints = append(hints, fmt.Sprintf("For example: const result = await %s.({ ... });", serverKeys[0])) + } + } + } else if strings.Contains(errorMessage, "is not a function") { + re := regexp.MustCompile(`(\w+(?:\.\w+)?)\s+is not a function`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + notFunction := match[1] + hints = append(hints, fmt.Sprintf("'%s' is not a function.", notFunction)) + hints = append(hints, "Ensure you're using the correct server key and tool name.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "To see available tools for a server, use listToolFiles and readToolFile.") + } + } else if strings.Contains(errorMessage, "Cannot read property") || + strings.Contains(errorMessage, "Cannot read properties") || + strings.Contains(errorMessage, "is not an object") { + hints = append(hints, "You're trying to access a property that doesn't exist or is undefined.") + hints = append(hints, "The tool response structure might be different than expected.") + hints = append(hints, "Check the console logs above to see the actual response structure from the tool.") + hints = append(hints, "Add console.log() statements to inspect the response before accessing properties.") + hints = append(hints, "Example: console.log('searchResults:', searchResults);") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } else { + hints = append(hints, "Check the error message above for details.") + hints = append(hints, "Check the console logs above to see tool responses and debug the issue.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "Ensure you're using the correct syntax: const result = await .({ ...args });") + } + + return hints +} diff --git a/core/mcp/codemode_listfiles.go b/core/mcp/codemode_listfiles.go new file mode 100644 index 000000000..730a1083d --- /dev/null +++ b/core/mcp/codemode_listfiles.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// createListToolFilesTool creates the listToolFiles tool definition for code mode. +// This tool allows listing all available virtual .d.ts declaration files for connected MCP servers. +// The description is dynamically generated based on the configured CodeModeBindingLevel. +// +// Returns: +// - schemas.ChatTool: The tool definition for listing tool files +func (m *ToolsManager) createListToolFilesTool() schemas.ChatTool { + bindingLevel := m.GetCodeModeBindingLevel() + var description string + + if bindingLevel == schemas.CodeModeBindingLevelServer { + description = "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers. " + + "Each server has a corresponding file (e.g., servers/.d.ts) that contains definitions for all tools in that server. " + + "Use readToolFile to read a specific server file and see all available tools. " + + "In code, access tools via: await serverName.toolName({ args }). " + + "The server names used in code correspond to the human-readable names shown in this listing. " + + "This tool is generic and works with any set of servers connected at runtime. " + + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools." + } else { + description = "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers, organized by individual tool. " + + "Each tool has a corresponding file (e.g., servers//.d.ts) that contains definitions for that specific tool. " + + "Use readToolFile to read a specific tool file and see its parameters and usage. " + + "In code, access tools via: await serverName.toolName({ args }). " + + "The server names used in code correspond to the human-readable names shown in this listing. " + + "This tool is generic and works with any set of servers connected at runtime. " + + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools." + } + + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeListToolFiles, + Description: schemas.Ptr(description), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{}, + Required: []string{}, + }, + }, + } +} + +// handleListToolFiles handles the listToolFiles tool call. +// It builds a tree structure listing all virtual .d.ts files available for code mode clients. +// The structure depends on the CodeModeBindingLevel: +// - "server": servers/.d.ts (one file per server) +// - "tool": servers//.d.ts (one file per tool) +// +// Parameters: +// - ctx: Context for accessing client tools +// - toolCall: The tool call request containing no arguments +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing the file tree structure +// - error: Any error that occurred during processing +func (m *ToolsManager) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + if len(availableToolsPerClient) == 0 { + responseText := "No servers are currently connected. There are no virtual .d.ts files available. " + + "Please ensure servers are connected before using this tool." + return createToolResponseMessage(toolCall, responseText), nil + } + + // Get the code mode binding level + bindingLevel := m.GetCodeModeBindingLevel() + + // Build file list based on binding level + var files []string + codeModeServerCount := 0 + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + codeModeServerCount++ + + if bindingLevel == schemas.CodeModeBindingLevelServer { + // Server-level: one file per server + files = append(files, fmt.Sprintf("servers/%s.d.ts", clientName)) + } else { + // Tool-level: one file per tool + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + toolFileName := fmt.Sprintf("servers/%s/%s.d.ts", clientName, tool.Function.Name) + files = append(files, toolFileName) + } + } + } + } + + if codeModeServerCount == 0 { + responseText := "Servers are connected but none are configured for code mode. " + + "There are no virtual .d.ts files available." + return createToolResponseMessage(toolCall, responseText), nil + } + + // Build tree structure from file list + responseText := buildVFSTree(files) + return createToolResponseMessage(toolCall, responseText), nil +} + +// VFS tree node structure for building hierarchical file structure +type treeNode struct { + isDirectory bool + children map[string]*treeNode +} + +// buildVFSTree creates a hierarchical tree structure from a flat list of file paths. +// It groups files by directory and formats them with proper indentation. +// +// Example input: +// - ["servers/calculator.d.ts", "servers/youtube.d.ts"] +// - ["servers/calculator/add.d.ts", "servers/youtube/GET_CHANNELS.d.ts"] +// +// Example output for server-level: +// servers/ +// calculator.d.ts +// youtube.d.ts +// +// Example output for tool-level: +// servers/ +// calculator/ +// add.d.ts +// youtube/ +// GET_CHANNELS.d.ts +func buildVFSTree(files []string) string { + if len(files) == 0 { + return "" + } + + root := &treeNode{ + isDirectory: true, + children: make(map[string]*treeNode), + } + + // Parse all files and build tree structure + for _, file := range files { + parts := strings.Split(file, "/") + current := root + + // Create all intermediate directories and final file + for i, part := range parts { + if _, exists := current.children[part]; !exists { + current.children[part] = &treeNode{ + isDirectory: i < len(parts)-1, // Last part is file, not directory + children: make(map[string]*treeNode), + } + } + current = current.children[part] + } + } + + // Render tree structure with proper indentation + var lines []string + renderTreeNode(root, "", &lines, true) + + return strings.Join(lines, "\n") +} + +// renderTreeNode recursively renders a tree node and its children with proper indentation. +func renderTreeNode(node *treeNode, indent string, lines *[]string, isRoot bool) { + // Get sorted keys for consistent output + var keys []string + for key := range node.children { + keys = append(keys, key) + } + + // Simple bubble sort for small lists (good enough for this use case) + for i := 0; i < len(keys); i++ { + for j := i + 1; j < len(keys); j++ { + if keys[j] < keys[i] { + keys[i], keys[j] = keys[j], keys[i] + } + } + } + + for _, key := range keys { + child := node.children[key] + + // Format the line + var line string + if isRoot { + // Root level - no indentation + if child.isDirectory { + line = key + "/" + } else { + line = key + } + } else { + // Non-root levels - add indentation + if child.isDirectory { + line = indent + key + "/" + } else { + line = indent + key + } + } + + *lines = append(*lines, line) + + // Recurse into children + if child.isDirectory && len(child.children) > 0 { + var nextIndent string + if isRoot { + nextIndent = " " + } else { + nextIndent = indent + " " + } + renderTreeNode(child, nextIndent, lines, false) + } + } +} diff --git a/core/mcp/codemode_readfile.go b/core/mcp/codemode_readfile.go new file mode 100644 index 000000000..776a6ac3b --- /dev/null +++ b/core/mcp/codemode_readfile.go @@ -0,0 +1,503 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// createReadToolFileTool creates the readToolFile tool definition for code mode. +// This tool allows reading virtual .d.ts declaration files for specific MCP servers/tools, +// generating TypeScript type definitions from the server's tool schemas. +// The description is dynamically generated based on the configured CodeModeBindingLevel. +// +// Returns: +// - schemas.ChatTool: The tool definition for reading tool files +func (m *ToolsManager) createReadToolFileTool() schemas.ChatTool { + bindingLevel := m.GetCodeModeBindingLevel() + + var fileNameDescription, toolDescription string + + if bindingLevel == schemas.CodeModeBindingLevelServer { + fileNameDescription = "The virtual filename from listToolFiles in format: servers/.d.ts (e.g., 'calculator.d.ts')" + toolDescription = "Reads a virtual .d.ts declaration file for a specific MCP server, generating TypeScript type definitions " + + "for all tools available on that server. The fileName should be in format servers/.d.ts as listed by listToolFiles. " + + "The function performs case-insensitive matching and removes the .d.ts extension. " + + "Optionally, you can specify startLine and endLine (1-based, inclusive) to read only a portion of the file. " + + "IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + + "This generates TypeScript type definitions describing all tools in the server and their argument types, " + + "enabling code-mode execution. Each tool can be accessed in code via: await serverName.toolName({ args }). " + + "Always follow this workflow: first use listToolFiles to see available servers, then use readToolFile to understand " + + "all available tool definitions for a server, and finally use executeToolCode to execute your code." + } else { + fileNameDescription = "The virtual filename from listToolFiles in format: servers//.d.ts (e.g., 'calculator/add.d.ts')" + toolDescription = "Reads a virtual .d.ts declaration file for a specific tool, generating TypeScript type definitions " + + "for that individual tool. The fileName should be in format servers//.d.ts as listed by listToolFiles. " + + "The function performs case-insensitive matching and removes the .d.ts extension. " + + "Optionally, you can specify startLine and endLine (1-based, inclusive) to read only a portion of the file. " + + "IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + + "This generates TypeScript type definitions for a single tool, describing its parameters and usage, " + + "enabling focused code-mode execution. The tool can be accessed in code via: await serverName.toolName({ args }). " + + "Always follow this workflow: first use listToolFiles to see available tools, then use readToolFile to understand " + + "a specific tool's definition, and finally use executeToolCode to execute your code." + } + + readToolFileProps := schemas.OrderedMap{ + "fileName": map[string]interface{}{ + "type": "string", + "description": fileNameDescription, + }, + "startLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based starting line number for partial file read (inclusive). Note: Line numbers start at 1, not 0. The first line is line 1.", + }, + "endLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based ending line number for partial file read (inclusive)", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeReadToolFile, + Description: schemas.Ptr(toolDescription), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &readToolFileProps, + Required: []string{"fileName"}, + }, + }, + } +} + +// handleReadToolFile handles the readToolFile tool call. +// It reads a virtual .d.ts file for a specific MCP server/tool, generates TypeScript type definitions, +// and optionally returns a portion of the file based on line range parameters. +// Supports both server-level files (e.g., "calculator.d.ts") and tool-level files (e.g., "calculator/add.d.ts"). +// +// Parameters: +// - ctx: Context for accessing client tools +// - toolCall: The tool call request containing fileName and optional startLine/endLine +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing the TypeScript definitions +// - error: Any error that occurred during processing +func (m *ToolsManager) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + fileName, ok := arguments["fileName"].(string) + if !ok || fileName == "" { + return nil, fmt.Errorf("fileName parameter is required and must be a string") + } + + // Parse the file path to extract server name and optional tool name + serverName, toolName, isToolLevel := parseVFSFilePath(fileName) + + // Get available tools per client + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + // Find matching client + var matchedClientName string + var matchedTools []schemas.ChatTool + matchCount := 0 + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + + clientNameLower := strings.ToLower(clientName) + serverNameLower := strings.ToLower(serverName) + + if clientNameLower == serverNameLower { + matchCount++ + if matchCount > 1 { + // Multiple matches found + errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName) + for name := range availableToolsPerClient { + if strings.ToLower(name) == serverNameLower { + errorMsg += fmt.Sprintf(" - %s\n", name) + } + } + errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity." + return createToolResponseMessage(toolCall, errorMsg), nil + } + + matchedClientName = clientName + + if isToolLevel { + // Tool-level: filter to specific tool + var foundTool *schemas.ChatTool + toolNameLower := strings.ToLower(toolName) + for i, tool := range tools { + if tool.Function != nil && strings.ToLower(tool.Function.Name) == toolNameLower { + foundTool = &tools[i] + break + } + } + + if foundTool == nil { + availableTools := make([]string, 0) + for _, tool := range tools { + if tool.Function != nil { + availableTools = append(availableTools, tool.Function.Name) + } + } + errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName) + for _, t := range availableTools { + errorMsg += fmt.Sprintf(" - %s/%s.d.ts\n", clientName, t) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + matchedTools = []schemas.ChatTool{*foundTool} + } else { + // Server-level: use all tools + matchedTools = tools + } + } + } + + if matchedClientName == "" { + // Build helpful error message with available files + bindingLevel := m.GetCodeModeBindingLevel() + var availableFiles []string + + for name := range availableToolsPerClient { + if bindingLevel == schemas.CodeModeBindingLevelServer { + availableFiles = append(availableFiles, fmt.Sprintf("%s.d.ts", name)) + } else { + client := m.clientManager.GetClientByName(name) + if client != nil && client.ExecutionConfig.IsCodeModeClient { + if tools, ok := availableToolsPerClient[name]; ok { + for _, tool := range tools { + if tool.Function != nil { + availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.d.ts", name, tool.Function.Name)) + } + } + } + } + } + } + + errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName) + for _, f := range availableFiles { + errorMsg += fmt.Sprintf(" - %s\n", f) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Generate TypeScript definitions + fileContent := generateTypeDefinitions(matchedClientName, matchedTools, isToolLevel) + lines := strings.Split(fileContent, "\n") + totalLines := len(lines) + + // Handle line slicing if provided + var startLine, endLine *int + if sl, ok := arguments["startLine"].(float64); ok { + slInt := int(sl) + startLine = &slInt + } + if el, ok := arguments["endLine"].(float64); ok { + elInt := int(el) + endLine = &elInt + } + + if startLine != nil || endLine != nil { + start := 1 + if startLine != nil { + start = *startLine + } + end := totalLines + if endLine != nil { + end = *endLine + } + + // Validate line numbers + if start < 1 || start > totalLines { + errorMsg := fmt.Sprintf("Invalid startLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%v, totalLines=%d", + start, totalLines, start, endLine, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + if end < 1 || end > totalLines { + errorMsg := fmt.Sprintf("Invalid endLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%d, totalLines=%d", + end, totalLines, start, end, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + if start > end { + errorMsg := fmt.Sprintf("Invalid line range: startLine (%d) must be less than or equal to endLine (%d). Total lines in file: %d", + start, end, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Slice lines (convert to 0-based indexing) + selectedLines := lines[start-1 : end] + fileContent = strings.Join(selectedLines, "\n") + } + + return createToolResponseMessage(toolCall, fileContent), nil +} + +// HELPER FUNCTIONS + +// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name. +// For server-level paths (e.g., "calculator.d.ts"), returns (serverName="calculator", toolName="", isToolLevel=false) +// For tool-level paths (e.g., "calculator/add.d.ts"), returns (serverName="calculator", toolName="add", isToolLevel=true) +// +// Parameters: +// - fileName: The virtual file path from listToolFiles +// +// Returns: +// - serverName: The name of the MCP server +// - toolName: The name of the tool (empty for server-level) +// - isToolLevel: Whether this is a tool-level path +func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) { + // Remove .d.ts extension + basePath := strings.TrimSuffix(fileName, ".d.ts") + + // Remove "servers/" prefix if present + basePath = strings.TrimPrefix(basePath, "servers/") + + // Check for path separator + parts := strings.Split(basePath, "/") + if len(parts) == 2 { + // Tool-level: "serverName/toolName" + return parts[0], parts[1], true + } + // Server-level: "serverName" + return basePath, "", false +} + +// generateTypeDefinitions generates TypeScript type definitions from ChatTool schemas +// with comprehensive comments to help LLMs understand how to use the tools. +// It creates interfaces for tool inputs and responses, along with function declarations. +// +// Parameters: +// - clientName: Name of the MCP client/server +// - tools: List of chat tools to generate definitions for +// - isToolLevel: Whether this is a tool-level definition (single tool) or server-level (all tools) +// +// Returns: +// - string: Complete TypeScript declaration file content +func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string { + var sb strings.Builder + + // Write comprehensive header comment + sb.WriteString("// ============================================================================\n") + if isToolLevel && len(tools) == 1 && tools[0].Function != nil { + // Tool-level: show individual tool name + sb.WriteString(fmt.Sprintf("// Type definitions for %s.%s tool\n", clientName, tools[0].Function.Name)) + } else { + // Server-level: show all tools in server + sb.WriteString(fmt.Sprintf("// Type definitions for %s MCP server\n", clientName)) + } + sb.WriteString("// ============================================================================\n") + sb.WriteString("//\n") + if isToolLevel && len(tools) == 1 { + sb.WriteString("// This file contains TypeScript type definitions for a specific tool on this MCP server.\n") + } else { + sb.WriteString("// This file contains TypeScript type definitions for all tools available on this MCP server.\n") + } + sb.WriteString("// These definitions enable code-mode execution as described in the MCP code execution pattern.\n") + sb.WriteString("//\n") + sb.WriteString("// USAGE INSTRUCTIONS:\n") + sb.WriteString("// 1. Each tool has an input interface (e.g., ToolNameInput) that defines the required parameters\n") + sb.WriteString("// 2. Each tool has a function declaration showing how to call it\n") + sb.WriteString("// 3. To use these tools in executeToolCode, you would call them like:\n") + sb.WriteString("// const result = await .({ ...args });\n") + sb.WriteString("//\n") + sb.WriteString("// NOTE: The server name used in executeToolCode is the same as the display name shown here.\n") + sb.WriteString("// ============================================================================\n\n") + + // Generate interfaces and function declarations for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + // Parse tool name for property name compatibility (used in virtual TypeScript files) + toolName := parseToolName(originalToolName) + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Generate input interface with detailed comments + inputInterfaceName := toPascalCase(toolName) + "Input" + sb.WriteString("// ----------------------------------------------------------------------------\n") + sb.WriteString(fmt.Sprintf("// Tool: %s\n", toolName)) + sb.WriteString("// ----------------------------------------------------------------------------\n") + if description != "" { + sb.WriteString(fmt.Sprintf("// Description: %s\n", description)) + } + sb.WriteString(fmt.Sprintf("// Input interface for %s\n", toolName)) + sb.WriteString(fmt.Sprintf("// This interface defines all parameters that can be passed to the %s tool.\n", toolName)) + sb.WriteString(fmt.Sprintf("interface %s {\n", inputInterfaceName)) + + if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil { + props := *tool.Function.Parameters.Properties + required := make(map[string]bool) + if tool.Function.Parameters.Required != nil { + for _, req := range tool.Function.Parameters.Required { + required[req] = true + } + } + + // Sort properties for consistent output + propNames := make([]string, 0, len(props)) + for name := range props { + propNames = append(propNames, name) + } + // Simple alphabetical sort + for i := 0; i < len(propNames)-1; i++ { + for j := i + 1; j < len(propNames); j++ { + if propNames[i] > propNames[j] { + propNames[i], propNames[j] = propNames[j], propNames[i] + } + } + } + + for _, propName := range propNames { + prop := props[propName] + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + + tsType := jsonSchemaToTypeScript(propMap) + optional := "" + if !required[propName] { + optional = "?" + } + + propDesc := "" + if desc, ok := propMap["description"].(string); ok && desc != "" { + propDesc = fmt.Sprintf(" // %s", desc) + } else { + propDesc = fmt.Sprintf(" // %s parameter", propName) + } + + requiredNote := "" + if required[propName] { + requiredNote = " (required)" + } else { + requiredNote = " (optional)" + } + + sb.WriteString(fmt.Sprintf(" %s%s: %s;%s%s\n", propName, optional, tsType, propDesc, requiredNote)) + } + } + + sb.WriteString("}\n\n") + + // Generate response interface with helpful comments + responseInterfaceName := toPascalCase(toolName) + "Response" + sb.WriteString(fmt.Sprintf("// Response interface for %s\n", toolName)) + sb.WriteString("// The actual response structure depends on the tool implementation.\n") + sb.WriteString("// This is a placeholder interface - the actual response may contain different fields.\n") + sb.WriteString(fmt.Sprintf("interface %s {\n", responseInterfaceName)) + sb.WriteString(" // Response structure depends on the tool implementation\n") + sb.WriteString(" // Common fields may include: result, error, data, etc.\n") + sb.WriteString(" [key: string]: any;\n") + sb.WriteString("}\n\n") + + // Generate function declaration with usage example + sb.WriteString(fmt.Sprintf("// Function declaration for %s\n", toolName)) + if description != "" { + sb.WriteString(fmt.Sprintf("// %s\n", description)) + } + sb.WriteString("//\n") + sb.WriteString("// Usage example in executeToolCode:\n") + sb.WriteString(fmt.Sprintf("// const result = await .%s({ ... });\n", toolName)) + sb.WriteString("// // Replace with the actual server name/ID\n") + sb.WriteString(fmt.Sprintf("// // Replace { ... } with the appropriate %sInput object\n", inputInterfaceName)) + sb.WriteString(fmt.Sprintf("export async function %s(input: %s): Promise<%s>;\n\n", toolName, inputInterfaceName, responseInterfaceName)) + } + + return sb.String() +} + +// jsonSchemaToTypeScript converts a JSON Schema type definition to a TypeScript type string. +// It handles basic types, arrays, enums, and defaults to "any" for unknown types. +// +// Parameters: +// - prop: JSON Schema property definition map +// +// Returns: +// - string: TypeScript type string representation +func jsonSchemaToTypeScript(prop map[string]interface{}) string { + // Check for explicit type + if typeVal, ok := prop["type"].(string); ok { + switch typeVal { + case "string": + return "string" + case "number", "integer": + return "number" + case "boolean": + return "boolean" + case "array": + itemsType := "any" + if items, ok := prop["items"].(map[string]interface{}); ok { + itemsType = jsonSchemaToTypeScript(items) + } + return fmt.Sprintf("%s[]", itemsType) + case "object": + return "object" + case "null": + return "null" + } + } + + // Check for enum + if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 { + enumStrs := make([]string, 0, len(enum)) + for _, e := range enum { + enumStrs = append(enumStrs, fmt.Sprintf("%q", e)) + } + return strings.Join(enumStrs, " | ") + } + + // Default to any + return "any" +} + +// toPascalCase converts a string to PascalCase format. +// It splits on underscores, hyphens, and spaces, then capitalizes the first letter +// of each word and lowercases the rest. +// +// Parameters: +// - s: Input string to convert +// +// Returns: +// - string: PascalCase formatted string +func toPascalCase(s string) string { + if s == "" { + return s + } + parts := strings.FieldsFunc(s, func(r rune) bool { + return r == '_' || r == '-' || r == ' ' + }) + result := "" + for _, part := range parts { + if len(part) > 0 { + result += strings.ToUpper(part[:1]) + strings.ToLower(part[1:]) + } + } + if result == "" { + return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) + } + return result +} diff --git a/core/mcp/health_monitor.go b/core/mcp/health_monitor.go new file mode 100644 index 000000000..6a55938fe --- /dev/null +++ b/core/mcp/health_monitor.go @@ -0,0 +1,231 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +const ( + // Health check configuration + DefaultHealthCheckInterval = 10 * time.Second // Interval between health checks + DefaultHealthCheckTimeout = 5 * time.Second // Timeout for each health check + MaxConsecutiveFailures = 5 // Number of failures before marking as unhealthy +) + +// ClientHealthMonitor tracks the health status of an MCP client +type ClientHealthMonitor struct { + manager *MCPManager + clientID string + interval time.Duration + timeout time.Duration + maxConsecutiveFailures int + mu sync.Mutex + ticker *time.Ticker + ctx context.Context + cancel context.CancelFunc + isMonitoring bool + consecutiveFailures int +} + +// NewClientHealthMonitor creates a new health monitor for an MCP client +func NewClientHealthMonitor( + manager *MCPManager, + clientID string, + interval time.Duration, +) *ClientHealthMonitor { + if interval == 0 { + interval = DefaultHealthCheckInterval + } + + return &ClientHealthMonitor{ + manager: manager, + clientID: clientID, + interval: interval, + timeout: DefaultHealthCheckTimeout, + maxConsecutiveFailures: MaxConsecutiveFailures, + isMonitoring: false, + consecutiveFailures: 0, + } +} + +// Start begins monitoring the client's health in a background goroutine +func (chm *ClientHealthMonitor) Start() { + chm.mu.Lock() + defer chm.mu.Unlock() + + if chm.isMonitoring { + return // Already monitoring + } + + chm.isMonitoring = true + chm.ctx, chm.cancel = context.WithCancel(context.Background()) + chm.ticker = time.NewTicker(chm.interval) + + go chm.monitorLoop() + logger.Debug(fmt.Sprintf("%s Health monitor started for client %s (interval: %v)", MCPLogPrefix, chm.clientID, chm.interval)) +} + +// Stop stops monitoring the client's health +func (chm *ClientHealthMonitor) Stop() { + chm.mu.Lock() + defer chm.mu.Unlock() + + if !chm.isMonitoring { + return // Not monitoring + } + + chm.isMonitoring = false + if chm.ticker != nil { + chm.ticker.Stop() + } + if chm.cancel != nil { + chm.cancel() + } + logger.Debug(fmt.Sprintf("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID)) +} + +// monitorLoop runs the health check loop +func (chm *ClientHealthMonitor) monitorLoop() { + for { + select { + case <-chm.ctx.Done(): + return + case <-chm.ticker.C: + chm.performHealthCheck() + } + } +} + +// performHealthCheck performs a health check on the client +func (chm *ClientHealthMonitor) performHealthCheck() { + // Get the client connection + chm.manager.mu.RLock() + clientState, exists := chm.manager.clientMap[chm.clientID] + chm.manager.mu.RUnlock() + + if !exists { + chm.Stop() + return + } + + if clientState.Conn == nil { + // Client not connected, mark as disconnected + chm.updateClientState(schemas.MCPConnectionStateDisconnected) + chm.incrementFailures() + return + } + + // Perform ping with timeout + ctx, cancel := context.WithTimeout(context.Background(), chm.timeout) + defer cancel() + + err := clientState.Conn.Ping(ctx) + if err != nil { + chm.incrementFailures() + + // After max consecutive failures, mark as disconnected + if chm.getConsecutiveFailures() >= chm.maxConsecutiveFailures { + chm.updateClientState(schemas.MCPConnectionStateDisconnected) + } + } else { + // Health check passed + chm.resetFailures() + chm.updateClientState(schemas.MCPConnectionStateConnected) + } +} + +// updateClientState updates the client's connection state +func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionState) { + chm.manager.mu.Lock() + clientState, exists := chm.manager.clientMap[chm.clientID] + if !exists { + chm.manager.mu.Unlock() + return + } + + // Only update if state changed + stateChanged := clientState.State != state + if stateChanged { + clientState.State = state + } + chm.manager.mu.Unlock() + + // Log after releasing the lock + if stateChanged { + logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, chm.clientID, state)) + } +} + +// incrementFailures increments the consecutive failure counter +func (chm *ClientHealthMonitor) incrementFailures() { + chm.mu.Lock() + defer chm.mu.Unlock() + chm.consecutiveFailures++ +} + +// resetFailures resets the consecutive failure counter +func (chm *ClientHealthMonitor) resetFailures() { + chm.mu.Lock() + defer chm.mu.Unlock() + chm.consecutiveFailures = 0 +} + +// getConsecutiveFailures returns the current consecutive failure count +func (chm *ClientHealthMonitor) getConsecutiveFailures() int { + chm.mu.Lock() + defer chm.mu.Unlock() + return chm.consecutiveFailures +} + +// HealthMonitorManager manages all client health monitors +type HealthMonitorManager struct { + monitors map[string]*ClientHealthMonitor + mu sync.RWMutex +} + +// NewHealthMonitorManager creates a new health monitor manager +func NewHealthMonitorManager() *HealthMonitorManager { + return &HealthMonitorManager{ + monitors: make(map[string]*ClientHealthMonitor), + } +} + +// StartMonitoring starts monitoring a specific client +func (hmm *HealthMonitorManager) StartMonitoring(monitor *ClientHealthMonitor) { + hmm.mu.Lock() + defer hmm.mu.Unlock() + + // Stop any existing monitor for this client + if existing, ok := hmm.monitors[monitor.clientID]; ok { + existing.Stop() + } + + hmm.monitors[monitor.clientID] = monitor + monitor.Start() +} + +// StopMonitoring stops monitoring a specific client +func (hmm *HealthMonitorManager) StopMonitoring(clientID string) { + hmm.mu.Lock() + defer hmm.mu.Unlock() + + if monitor, ok := hmm.monitors[clientID]; ok { + monitor.Stop() + delete(hmm.monitors, clientID) + } +} + +// StopAll stops all monitoring +func (hmm *HealthMonitorManager) StopAll() { + hmm.mu.Lock() + defer hmm.mu.Unlock() + + for _, monitor := range hmm.monitors { + monitor.Stop() + } + hmm.monitors = make(map[string]*ClientHealthMonitor) +} diff --git a/core/mcp/init.go b/core/mcp/init.go new file mode 100644 index 000000000..d0eb389c1 --- /dev/null +++ b/core/mcp/init.go @@ -0,0 +1,9 @@ +package mcp + +import "github.com/maximhq/bifrost/core/schemas" + +var logger schemas.Logger + +func SetLogger(l schemas.Logger) { + logger = l +} diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go new file mode 100644 index 000000000..c7ec34d25 --- /dev/null +++ b/core/mcp/mcp.go @@ -0,0 +1,288 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + + "github.com/mark3labs/mcp-go/server" +) + +// ============================================================================ +// CONSTANTS +// ============================================================================ + +const ( + // MCP defaults and identifiers + BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost + BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client + BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap + MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix + MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment + + // Context keys for client filtering in requests + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. + MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName/toolName" format) +) + +// ============================================================================ +// TYPE DEFINITIONS +// ============================================================================ + +// MCPManager manages MCP integration for Bifrost core. +// It provides a bridge between Bifrost and various MCP servers, supporting +// both local tool hosting and external MCP server connections. +type MCPManager struct { + ctx context.Context + toolsManager *ToolsManager // Handler for MCP tools + server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based) + clientMap map[string]*schemas.MCPClientState // Map of MCP client names to their configurations + mu sync.RWMutex // Read-write mutex for thread-safe operations + serverRunning bool // Track whether local MCP server is running + healthMonitorManager *HealthMonitorManager // Manager for client health monitors +} + +// MCPToolFunction is a generic function type for handling tool calls with typed arguments. +// T represents the expected argument structure for the tool. +type MCPToolFunction[T any] func(args T) (string, error) + +// ============================================================================ +// CONSTRUCTOR AND INITIALIZATION +// ============================================================================ + +// NewMCPManager creates and initializes a new MCP manager instance. +// +// Parameters: +// - config: MCP configuration including server port and client configs +// - logger: Logger instance for structured logging (uses default if nil) +// +// Returns: +// - *MCPManager: Initialized manager instance +// - error: Any initialization error +func NewMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas.Logger) *MCPManager { + SetLogger(logger) + // Set default values + if config.ToolManagerConfig == nil { + config.ToolManagerConfig = &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + } + } + // Creating new instance + manager := &MCPManager{ + ctx: ctx, + clientMap: make(map[string]*schemas.MCPClientState), + healthMonitorManager: NewHealthMonitorManager(), + } + manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc) + // Process client configs: create client map entries and establish connections + if len(config.ClientConfigs) > 0 { + for _, clientConfig := range config.ClientConfigs { + if err := manager.AddClient(clientConfig); err != nil { + logger.Warn(fmt.Sprintf("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)) + } + } + } + logger.Info(MCPLogPrefix + " MCP Manager initialized") + return manager +} + +// AddToolsToRequest parses available MCP tools from the context and adds them to the request. +// It respects context-based filtering for clients and tools, and returns the modified request +// with tools attached. +// +// Parameters: +// - ctx: Context containing optional client/tool filtering keys +// - req: The Bifrost request to add tools to +// +// Returns: +// - *schemas.BifrostRequest: The request with tools added +func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + return m.toolsManager.ParseAndAddToolsToRequest(ctx, req) +} + +func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { + return m.toolsManager.GetAvailableTools(ctx) +} + +// ExecuteChatTool executes a single tool call and returns the result as a chat message. +// This is the primary tool executor and is used by both Chat Completions and Responses APIs. +// +// The method accepts tool calls in Chat API format (ChatAssistantMessageToolCall) and returns +// results in Chat API format (ChatMessage). For Responses API users: +// - Convert ResponsesToolMessage to ChatAssistantMessageToolCall using ToChatAssistantMessageToolCall() +// - Execute the tool with this method +// - Convert the result back using ChatMessage.ToResponsesToolMessage() +// +// Alternatively, use ExecuteResponsesTool() in the ToolsManager for a type-safe wrapper +// that handles format conversions automatically. +// +// Parameters: +// - ctx: Context for the tool execution +// - toolCall: The tool call to execute in Chat API format +// +// Returns: +// - *schemas.ChatMessage: The result message containing tool execution output +// - error: Any error that occurred during tool execution +func (m *MCPManager) ExecuteChatTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + return m.toolsManager.ExecuteChatTool(ctx, toolCall) +} + +// ExecuteResponsesTool executes a single tool call and returns the result as a responses message. + +// - ctx: Context for the tool execution +// - toolCall: The tool call to execute in Responses API format +// +// Returns: +// - *schemas.ResponsesMessage: The result message containing tool execution output +// - error: Any error that occurred during tool execution +func (m *MCPManager) ExecuteResponsesTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, error) { + return m.toolsManager.ExecuteResponsesTool(ctx, toolCall) +} + +// UpdateToolManagerConfig updates the configuration for the tool manager. +// This allows runtime updates to settings like execution timeout and max agent depth. +// +// Parameters: +// - config: The new tool manager configuration to apply +func (m *MCPManager) UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) { + m.toolsManager.UpdateConfig(config) +} + +// CheckAndExecuteAgentForChatRequest checks if the chat response contains tool calls, +// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls +// are present, it returns the original response unchanged. +// +// Agent mode enables autonomous tool execution where: +// 1. Tool calls are automatically executed +// 2. Results are fed back to the LLM +// 3. The loop continues until no more tool calls are made or max depth is reached +// 4. Non-auto-executable tools are returned to the caller +// +// This method is available for both Chat Completions and Responses APIs. +// For Responses API, use CheckAndExecuteAgentForResponsesRequest(). +// +// Parameters: +// - ctx: Context for the agent execution +// - req: The original chat request +// - response: The initial chat response that may contain tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution (or original if no tool calls) +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *MCPManager) CheckAndExecuteAgentForChatRequest( + ctx *context.Context, + req *schemas.BifrostChatRequest, + response *schemas.BifrostChatResponse, + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if makeReq == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "makeReq is required to execute agent mode", + }, + } + } + // Check if initial response has tool calls + if !hasToolCallsForChatResponse(response) { + logger.Debug("No tool calls detected, returning response") + return response, nil + } + // Execute agent mode + return m.toolsManager.ExecuteAgentForChatRequest(ctx, req, response, makeReq) +} + +// CheckAndExecuteAgentForResponsesRequest checks if the responses response contains tool calls, +// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls +// are present, it returns the original response unchanged. +// +// Agent mode for Responses API works identically to Chat API: +// 1. Detects tool calls in the response (function_call messages) +// 2. Automatically executes tools in parallel when possible +// 3. Feeds results back to the LLM in Responses API format +// 4. Continues the loop until no more tool calls or max depth reached +// 5. Returns non-auto-executable tools to the caller +// +// Format Handling: +// This method automatically handles format conversions: +// - Responses tool calls (ResponsesToolMessage) are converted to Chat format for execution +// - Tool execution results are converted back to Responses format (ResponsesMessage) +// - All conversions use the adapters in agent_adaptors.go and converters in schemas/mux.go +// +// This provides full feature parity between Chat Completions and Responses APIs for tool execution. +// +// Parameters: +// - ctx: Context for the agent execution +// - req: The original responses request +// - response: The initial responses response that may contain tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution (or original if no tool calls) +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest( + ctx *context.Context, + req *schemas.BifrostResponsesRequest, + response *schemas.BifrostResponsesResponse, + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if makeReq == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "makeReq is required to execute agent mode", + }, + } + } + // Check if initial response has tool calls + if !hasToolCallsForResponsesResponse(response) { + logger.Debug("No tool calls detected, returning response") + return response, nil + } + // Execute agent mode + return m.toolsManager.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq) +} + +// Cleanup performs cleanup of all MCP resources including clients and local server. +// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and +// cleans up the local MCP server. It handles proper cancellation of SSE contexts +// and closes all transport connections. +// +// Returns: +// - error: Always returns nil, but maintains error interface for consistency +func (m *MCPManager) Cleanup() error { + // Stop all health monitors first + m.healthMonitorManager.StopAll() + + m.mu.Lock() + defer m.mu.Unlock() + + // Disconnect all external MCP clients + for id := range m.clientMap { + if err := m.removeClientUnsafe(id); err != nil { + logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, id, err) + } + } + + // Clear the client map + m.clientMap = make(map[string]*schemas.MCPClientState) + + // Clear local server reference + // Note: mark3labs/mcp-go STDIO server cleanup is handled automatically + if m.server != nil { + logger.Info(MCPLogPrefix + " Clearing local MCP server reference") + m.server = nil + m.serverRunning = false + } + + logger.Info(MCPLogPrefix + " MCP cleanup completed") + return nil +} diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go new file mode 100644 index 000000000..7bb048645 --- /dev/null +++ b/core/mcp/toolmanager.go @@ -0,0 +1,557 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +type ClientManager interface { + GetClientByName(clientName string) *schemas.MCPClientState + GetClientForTool(toolName string) *schemas.MCPClientState + GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool +} + +type ToolsManager struct { + toolExecutionTimeout atomic.Value + maxAgentDepth atomic.Int32 + codeModeBindingLevel atomic.Value // Stores CodeModeBindingLevel + clientManager ClientManager + logMu sync.Mutex // Protects concurrent access to logs slice in codemode execution + + // Function to fetch a new request ID for each tool call result message in agent mode, + // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. + // This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode. + // If not provider, same request ID is used for all tool call result messages without any overrides. + fetchNewRequestIDFunc func(ctx context.Context) string +} + +const ( + ToolTypeListToolFiles string = "listToolFiles" + ToolTypeReadToolFile string = "readToolFile" + ToolTypeExecuteToolCode string = "executeToolCode" +) + +// NewToolsManager creates and initializes a new tools manager instance. +// It validates the configuration, sets defaults if needed, and initializes atomic values +// for thread-safe configuration updates. +// +// Parameters: +// - config: Tool manager configuration with execution timeout and max agent depth +// - clientManager: Client manager interface for accessing MCP clients and tools +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode +// +// Returns: +// - *ToolsManager: Initialized tools manager instance +func NewToolsManager(config *schemas.MCPToolManagerConfig, clientManager ClientManager, fetchNewRequestIDFunc func(ctx context.Context) string) *ToolsManager { + if config == nil { + config = &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + CodeModeBindingLevel: schemas.CodeModeBindingLevelServer, + } + } + if config.MaxAgentDepth <= 0 { + config.MaxAgentDepth = schemas.DefaultMaxAgentDepth + } + if config.ToolExecutionTimeout <= 0 { + config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout + } + // Default to server-level binding if not specified + if config.CodeModeBindingLevel == "" { + config.CodeModeBindingLevel = schemas.CodeModeBindingLevelServer + } + manager := &ToolsManager{ + clientManager: clientManager, + fetchNewRequestIDFunc: fetchNewRequestIDFunc, + } + // Initialize atomic values + manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + manager.codeModeBindingLevel.Store(config.CodeModeBindingLevel) + + logger.Info(fmt.Sprintf("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)) + return manager +} + +// GetAvailableTools returns the available tools for the given context. +func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + // Flatten tools from all clients into a single slice, avoiding duplicates + var availableTools []schemas.ChatTool + var includeCodeModeTools bool + // Track tool names to prevent duplicates + seenToolNames := make(map[string]bool) + + for clientName, clientTools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if client.ExecutionConfig.IsCodeModeClient { + includeCodeModeTools = true + } else { + // Add tools from this client, checking for duplicates + for _, tool := range clientTools { + if tool.Function != nil && tool.Function.Name != "" { + if !seenToolNames[tool.Function.Name] { + availableTools = append(availableTools, tool) + seenToolNames[tool.Function.Name] = true + } + } + } + } + } + + if includeCodeModeTools { + codeModeTools := []schemas.ChatTool{ + m.createListToolFilesTool(), + m.createReadToolFileTool(), + m.createExecuteToolCodeTool(), + } + // Add code mode tools, checking for duplicates + for _, tool := range codeModeTools { + if tool.Function != nil && tool.Function.Name != "" { + if !seenToolNames[tool.Function.Name] { + availableTools = append(availableTools, tool) + seenToolNames[tool.Function.Name] = true + } + } + } + } + + return availableTools +} + +// buildIntegrationDuplicateCheckMap builds a map of tool names to check for duplicates +// based on the integration user agent. This includes both direct tool names and +// integration-specific naming patterns from existing tools in the request. +// +// Parameters: +// - existingTools: List of existing tools in the request +// - integrationUserAgent: Integration user agent string (e.g., "claude-cli") +// - availableToolsPerClient: Map of client names to their available tools (for reverse pattern matching) +// +// Returns: +// - map[string]bool: Map of tool names/patterns to check against +func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integrationUserAgent string) map[string]bool { + duplicateCheckMap := make(map[string]bool) + + // Add direct tool names + for _, tool := range existingTools { + if tool.Function != nil && tool.Function.Name != "" { + duplicateCheckMap[tool.Function.Name] = true + } + } + + // Add integration-specific patterns from existing tools + switch integrationUserAgent { + case "claude-cli": + // Claude CLI uses pattern: mcp__{foreign_name}__{tool_name} + // The middle part is a foreign name we cannot check for, so we extract the last part + // Examples: + // mcp__bifrost__executeToolCode -> executeToolCode + // mcp__bifrost__listToolFiles -> listToolFiles + // mcp__bifrost__readToolFile -> readToolFile + // mcp__calculator__calculator_add -> calculator_add + for _, tool := range existingTools { + if tool.Function != nil && tool.Function.Name != "" { + existingToolName := tool.Function.Name + // Check if existing tool matches Claude CLI pattern: mcp__*__{tool_name} + if strings.HasPrefix(existingToolName, "mcp__") { + // Split on __ and take the last entry (the tool_name) + parts := strings.Split(existingToolName, "__") + if len(parts) >= 3 { + toolName := parts[len(parts)-1] // Last part is the tool name + // Map Claude CLI pattern back to our tool name format + // This handles both regular MCP tools and code mode tools + if toolName != "" { + duplicateCheckMap[toolName] = true + // Also keep the original pattern for direct matching + duplicateCheckMap[existingToolName] = true + } + } + } + } + } + // Add more integration-specific patterns here as needed + // case "another-integration": + // // Add patterns for other integrations + } + + return duplicateCheckMap +} + +// ParseAndAddToolsToRequest parses the available tools per client and adds them to the Bifrost request. +// +// Parameters: +// - ctx: Execution context +// - req: Bifrost request +// - availableToolsPerClient: Map of client name to its available tools +// +// Returns: +// - *schemas.BifrostRequest: Bifrost request with MCP tools added +func (m *ToolsManager) ParseAndAddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + // MCP is only supported for chat and responses requests + if req.ChatRequest == nil && req.ResponsesRequest == nil { + return req + } + + availableTools := m.GetAvailableTools(ctx) + + if len(availableTools) == 0 { + return req + } + + // Get integration user agent for duplicate checking + var integrationUserAgentStr string + integrationUserAgent := ctx.Value(schemas.BifrostContextKey("integration-user-agent")) + if integrationUserAgent != nil { + if str, ok := integrationUserAgent.(string); ok { + integrationUserAgentStr = str + } + } + + if len(availableTools) > 0 { + switch req.RequestType { + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ChatRequest.Params == nil { + req.ChatRequest.Params = &schemas.ChatParameters{} + } + + tools := req.ChatRequest.Params.Tools + + // Build integration-aware duplicate check map + duplicateCheckMap := buildIntegrationDuplicateCheckMap(tools, integrationUserAgentStr) + + // Add MCP tools that are not already present + for _, mcpTool := range availableTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + toolName := mcpTool.Function.Name + + // Check for duplicates using integration-aware logic + if !duplicateCheckMap[toolName] { + tools = append(tools, mcpTool) + // Update the map to prevent duplicates within MCP tools as well + duplicateCheckMap[toolName] = true + } + } + req.ChatRequest.Params.Tools = tools + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ResponsesRequest.Params == nil { + req.ResponsesRequest.Params = &schemas.ResponsesParameters{} + } + + tools := req.ResponsesRequest.Params.Tools + + // Convert Responses tools to ChatTool format for duplicate checking + existingChatTools := make([]schemas.ChatTool, 0, len(tools)) + for _, tool := range tools { + if tool.Name != nil { + existingChatTools = append(existingChatTools, schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: *tool.Name, + }, + }) + } + } + + // Build integration-aware duplicate check map + duplicateCheckMap := buildIntegrationDuplicateCheckMap(existingChatTools, integrationUserAgentStr) + + // Add MCP tools that are not already present + for _, mcpTool := range availableTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + toolName := mcpTool.Function.Name + + // Check for duplicates using integration-aware logic + if !duplicateCheckMap[toolName] { + responsesTool := mcpTool.ToResponsesTool() + // Skip if the converted tool has nil Name + if responsesTool.Name == nil { + continue + } + + tools = append(tools, *responsesTool) + // Update the map to prevent duplicates within MCP tools as well + duplicateCheckMap[toolName] = true + } + } + req.ResponsesRequest.Params.Tools = tools + } + } + return req +} + +// ============================================================================ +// TOOL REGISTRATION AND DISCOVERY +// ============================================================================ + +// ExecuteChatTool executes a tool call in Chat Completions API format and returns the result as a chat tool message. +// This is the primary tool executor that works with both Chat Completions and Responses APIs. +// +// For Responses API users, use ExecuteResponsesTool() for a more type-safe interface. +// However, internally this method is format-agnostic - it executes the tool and returns +// a ChatMessage which can then be converted to ResponsesMessage via ToResponsesToolMessage(). +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - *schemas.ChatMessage: Tool message with execution result +// - error: Any execution error +func (m *ToolsManager) ExecuteChatTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + if toolCall.Function.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + toolName := *toolCall.Function.Name + + // Handle code mode tools + switch toolName { + case ToolTypeListToolFiles: + return m.handleListToolFiles(ctx, toolCall) + case ToolTypeReadToolFile: + return m.handleReadToolFile(ctx, toolCall) + case ToolTypeExecuteToolCode: + return m.handleExecuteToolCode(ctx, toolCall) + default: + // Check if the user has permission to execute the tool call + availableTools := m.clientManager.GetToolPerClient(ctx) + toolFound := false + for _, tools := range availableTools { + for _, mcpTool := range tools { + if mcpTool.Function != nil && mcpTool.Function.Name == toolName { + toolFound = true + break + } + } + if toolFound { + break + } + } + + if !toolFound { + return nil, fmt.Errorf("tool '%s' is not available or not permitted", toolName) + } + + client := m.clientManager.GetClientForTool(toolName) + if client == nil { + return nil, fmt.Errorf("client not found for tool %s", toolName) + } + + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) + } + + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name, not the prefixed version + originalToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name) + + // Call the tool via MCP client -> MCP server + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: arguments, + }, + } + + logger.Debug(fmt.Sprintf("%s Starting tool execution: %s via client: %s", MCPLogPrefix, toolName, client.ExecutionConfig.Name)) + + // Create timeout context for tool execution + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + // Check if it was a timeout error + if toolCtx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, fmt.Errorf("MCP tool call failed: %v", callErr) + } + + logger.Debug(fmt.Sprintf("%s Tool execution completed: %s", MCPLogPrefix, toolName)) + + // Extract text from MCP response + responseText := extractTextFromMCPResponse(toolResponse, toolName) + + // Create tool response message + return createToolResponseMessage(toolCall, responseText), nil + } +} + +// ExecuteToolForResponses executes a tool call from a Responses API tool message and returns +// the result in Responses API format. This is a type-safe wrapper around ExecuteTool that +// handles the conversion between Responses and Chat API formats. +// +// This method: +// 1. Converts the Responses tool message to Chat API format +// 2. Executes the tool using the standard tool executor +// 3. Converts the result back to Responses API format +// +// Parameters: +// - ctx: Execution context +// - toolMessage: The Responses API tool message to execute +// - callID: The original call ID from the Responses API +// +// Returns: +// - *schemas.ResponsesMessage: Tool result message in Responses API format +// - error: Any execution error +// +// Example: +// +// responsesToolMsg := &schemas.ResponsesToolMessage{ +// Name: Ptr("calculate"), +// Arguments: Ptr("{\"x\": 10, \"y\": 20}"), +// } +// resultMsg, err := toolsManager.ExecuteResponsesTool(ctx, responsesToolMsg, "call-123") +// // resultMsg is a ResponsesMessage with type=function_call_output +func (m *ToolsManager) ExecuteResponsesTool( + ctx context.Context, + toolMessage *schemas.ResponsesToolMessage, +) (*schemas.ResponsesMessage, error) { + if toolMessage == nil { + return nil, fmt.Errorf("tool message is nil") + } + if toolMessage.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + + // Convert Responses format to Chat format for execution + chatToolCall := toolMessage.ToChatAssistantMessageToolCall() + if chatToolCall == nil { + return nil, fmt.Errorf("failed to convert Responses tool message to Chat format") + } + + // Execute the tool using the standard executor + chatResult, err := m.ExecuteChatTool(ctx, *chatToolCall) + if err != nil { + return nil, err + } + + // Convert the result back to Responses format + responsesMessage := chatResult.ToResponsesToolMessage() + if responsesMessage == nil { + return nil, fmt.Errorf("failed to convert tool result to Responses format") + } + + return responsesMessage, nil +} + +// ExecuteAgentForChatRequest executes agent mode for a chat request, handling +// iterative tool calls up to the configured maximum depth. It delegates to the +// shared agent execution logic with the manager's configuration and dependencies. +// +// Parameters: +// - ctx: Context for agent execution +// - req: The original chat request +// - resp: The initial chat response containing tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *ToolsManager) ExecuteAgentForChatRequest( + ctx *context.Context, + req *schemas.BifrostChatRequest, + resp *schemas.BifrostChatResponse, + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return ExecuteAgentForChatRequest( + ctx, + int(m.maxAgentDepth.Load()), + req, + resp, + makeReq, + m.fetchNewRequestIDFunc, + m.ExecuteChatTool, + m.clientManager, + ) +} + +// ExecuteAgentForResponsesRequest executes agent mode for a responses request, handling +// iterative tool calls up to the configured maximum depth. It delegates to the +// shared agent execution logic with the manager's configuration and dependencies. +// +// Parameters: +// - ctx: Context for agent execution +// - req: The original responses request +// - resp: The initial responses response containing tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *ToolsManager) ExecuteAgentForResponsesRequest( + ctx *context.Context, + req *schemas.BifrostResponsesRequest, + resp *schemas.BifrostResponsesResponse, + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return ExecuteAgentForResponsesRequest( + ctx, + int(m.maxAgentDepth.Load()), + req, + resp, + makeReq, + m.fetchNewRequestIDFunc, + m.ExecuteChatTool, + m.clientManager, + ) +} + +// UpdateConfig updates tool manager configuration atomically. +// This method is safe to call concurrently from multiple goroutines. +func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { + if config == nil { + return + } + if config.ToolExecutionTimeout > 0 { + m.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + } + if config.MaxAgentDepth > 0 { + m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + } + if config.CodeModeBindingLevel != "" { + m.codeModeBindingLevel.Store(config.CodeModeBindingLevel) + } + + logger.Info(fmt.Sprintf("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)) +} + +// GetCodeModeBindingLevel returns the current code mode binding level. +// This method is safe to call concurrently from multiple goroutines. +func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel { + val := m.codeModeBindingLevel.Load() + if val == nil { + return schemas.CodeModeBindingLevelServer + } + return val.(schemas.CodeModeBindingLevel) +} diff --git a/core/mcp/utils.go b/core/mcp/utils.go new file mode 100644 index 000000000..3fe8b9e7c --- /dev/null +++ b/core/mcp/utils.go @@ -0,0 +1,567 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "regexp" + "slices" + "strings" + "unicode" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// GetClientForTool safely finds a client that has the specified tool. +// Returns a copy of the client state to avoid data races. Callers should be aware +// that fields like Conn and ToolMap are still shared references and may be modified +// by other goroutines, but the struct itself is safe from concurrent modification. +func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, client := range m.clientMap { + if _, exists := client.ToolMap[toolName]; exists { + // Return a copy to prevent TOCTOU race conditions + // The caller receives a snapshot of the client state at this point in time + clientCopy := *client + return &clientCopy + } + } + return nil +} + +// GetToolPerClient returns all tools from connected MCP clients. +// Applies client filtering if specified in the context. +// Returns a map of client name to its available tools. +// Parameters: +// - ctx: Execution context +// +// Returns: +// - map[string][]schemas.ChatTool: Map of client name to its available tools +func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + m.mu.RLock() + defer m.mu.RUnlock() + + var includeClients []string + + // Extract client filtering from request context + if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + includeClients = existingIncludeClients + } + + tools := make(map[string][]schemas.ChatTool) + for _, client := range m.clientMap { + // Use client name as the key (not ID) + clientName := client.ExecutionConfig.Name + + // Apply client filtering logic + if !shouldIncludeClient(clientName, includeClients) { + logger.Debug(fmt.Sprintf("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)) + continue + } + + logger.Debug(fmt.Sprintf("Checking tools for MCP client %s with tools to execute: %v", clientName, client.ExecutionConfig.ToolsToExecute)) + + // Add all tools from this client + for toolName, tool := range client.ToolMap { + // Check if tool should be skipped based on client configuration + if shouldSkipToolForConfig(toolName, client.ExecutionConfig) { + logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in tools to execute list", MCPLogPrefix, toolName)) + continue + } + + // Check if tool should be skipped based on request context + if shouldSkipToolForRequest(ctx, clientName, toolName) { + logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in include tools list", MCPLogPrefix, toolName)) + continue + } + + tools[clientName] = append(tools[clientName], tool) + } + if len(tools[clientName]) > 0 { + logger.Debug(fmt.Sprintf("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)) + } + } + return tools +} + +// GetClientByName returns a client by name. +// +// Parameters: +// - clientName: Name of the client to get +// +// Returns: +// - *schemas.MCPClientState: Client state if found, nil otherwise +func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + for _, client := range m.clientMap { + if client.ExecutionConfig.Name == clientName { + // Return a copy to prevent TOCTOU race conditions + // The caller receives a snapshot of the client state at this point in time + clientCopy := *client + return &clientCopy + } + } + return nil +} + +// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. +func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string) (map[string]schemas.ChatTool, error) { + // Get available tools from external server + listRequest := mcp.ListToolsRequest{ + PaginatedRequest: mcp.PaginatedRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsList), + }, + }, + } + + toolsResponse, err := client.ListTools(ctx, listRequest) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %v", err) + } + + if toolsResponse == nil { + return make(map[string]schemas.ChatTool), nil // No tools available + } + + tools := make(map[string]schemas.ChatTool) + + // toolsResponse is already a ListToolsResult + for _, mcpTool := range toolsResponse.Tools { + // Convert MCP tool schema to Bifrost format + bifrostTool := convertMCPToolToBifrostSchema(&mcpTool) + // Prefix tool name with client name to make it permanent + prefixedToolName := fmt.Sprintf("%s_%s", clientName, mcpTool.Name) + // Update the tool's function name to match the prefixed name + if bifrostTool.Function != nil { + bifrostTool.Function.Name = prefixedToolName + } + tools[prefixedToolName] = bifrostTool + } + + return tools, nil +} + +// shouldIncludeClient determines if a client should be included based on filtering rules. +func shouldIncludeClient(clientName string, includeClients []string) bool { + // If includeClients is specified (not nil), apply whitelist filtering + if includeClients != nil { + // Handle empty array [] - means no clients are included + if len(includeClients) == 0 { + return false // No clients allowed + } + + // Handle wildcard "*" - if present, all clients are included + if slices.Contains(includeClients, "*") { + return true // All clients allowed + } + + // Check if specific client is in the list + return slices.Contains(includeClients, clientName) + } + + // Default: include all clients when no filtering specified (nil case) + return true +} + +// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). +func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { + // If ToolsToExecute is specified (not nil), apply filtering + if config.ToolsToExecute != nil { + // Handle empty array [] - means no tools are allowed + if len(config.ToolsToExecute) == 0 { + return true // No tools allowed + } + + // Handle wildcard "*" - if present, all tools are allowed + if slices.Contains(config.ToolsToExecute, "*") { + return false // All tools allowed + } + + // Check if specific tool is in the allowed list + return !slices.Contains(config.ToolsToExecute, toolName) // Tool not in allowed list + } + + return true // Tool is skipped (nil is treated as [] - no tools) +} + +// canAutoExecuteTool checks if a tool can be auto-executed based on client configuration. +// Returns true if the tool can be auto-executed, false otherwise. +func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { + // First check if tool is in ToolsToExecute (must be executable first) + if shouldSkipToolForConfig(toolName, config) { + return false // Tool is not in ToolsToExecute, so it cannot be auto-executed + } + + // If ToolsToAutoExecute is specified (not nil), apply filtering + if config.ToolsToAutoExecute != nil { + // Handle empty array [] - means no tools are auto-executed + if len(config.ToolsToAutoExecute) == 0 { + return false // No tools auto-executed + } + + // Handle wildcard "*" - if present, all tools are auto-executed + if slices.Contains(config.ToolsToAutoExecute, "*") { + return true // All tools auto-executed + } + + // Check if specific tool is in the auto-execute list + return slices.Contains(config.ToolsToAutoExecute, toolName) + } + + return false // Tool is not auto-executed (nil is treated as [] - no tools) +} + +// shouldSkipToolForRequest checks if a tool should be skipped based on the request context. +func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { + includeTools := ctx.Value(MCPContextKeyIncludeTools) + + if includeTools != nil { + // Try []string first (preferred type) + if includeToolsList, ok := includeTools.([]string); ok { + // Handle empty array [] - means no tools are included + if len(includeToolsList) == 0 { + return true // No tools allowed + } + + // Handle wildcard "clientName/*" - if present, all tools are included for this client + if slices.Contains(includeToolsList, fmt.Sprintf("%s/*", clientName)) { + return false // All tools allowed + } + + // Check if specific tool is in the list (format: clientName/toolName) + fullToolName := fmt.Sprintf("%s/%s", clientName, toolName) + if slices.Contains(includeToolsList, fullToolName) { + return false // Tool is explicitly allowed + } + + // If includeTools is specified but this tool is not in it, skip it + return true + } + } + + return false // Tool is allowed (default when no filtering specified) +} + +// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. +func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.ChatTool { + var properties *schemas.OrderedMap + if len(mcpTool.InputSchema.Properties) > 0 { + orderedProps := make(schemas.OrderedMap, len(mcpTool.InputSchema.Properties)) + maps.Copy(orderedProps, mcpTool.InputSchema.Properties) + properties = &orderedProps + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: mcpTool.Name, + Description: schemas.Ptr(mcpTool.Description), + Parameters: &schemas.ToolFunctionParameters{ + Type: mcpTool.InputSchema.Type, + Properties: properties, + Required: mcpTool.InputSchema.Required, + }, + }, + } +} + +// extractTextFromMCPResponse extracts text content from an MCP tool response. +func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { + if toolResponse == nil { + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) + } + + var result strings.Builder + for _, contentBlock := range toolResponse.Content { + // Handle typed content + switch content := contentBlock.(type) { + case mcp.TextContent: + result.WriteString(content.Text) + case mcp.ImageContent: + result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.AudioContent: + result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.EmbeddedResource: + result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) + default: + // Fallback: try to extract from map structure + if jsonBytes, err := json.Marshal(contentBlock); err == nil { + var contentMap map[string]interface{} + if json.Unmarshal(jsonBytes, &contentMap) == nil { + if text, ok := contentMap["text"].(string); ok { + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) + continue + } + } + // Final fallback: serialize as JSON + result.WriteString(string(jsonBytes)) + } + } + } + + if result.Len() > 0 { + return strings.TrimSpace(result.String()) + } + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) +} + +// createToolResponseMessage creates a tool response message with the execution result. +func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &responseText, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// validateMCPClientConfig validates an MCP client configuration. +func validateMCPClientConfig(config *schemas.MCPClientConfig) error { + if strings.TrimSpace(config.ID) == "" { + return fmt.Errorf("id is required for MCP client config") + } + if err := validateMCPClientName(config.Name); err != nil { + return fmt.Errorf("invalid name for MCP client: %w", err) + } + if config.ConnectionType == "" { + return fmt.Errorf("connection type is required for MCP client config") + } + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSSE: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSTDIO: + if config.StdioConfig == nil { + return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeInProcess: + // InProcess requires a server instance to be provided programmatically + // This cannot be validated from JSON config - the server must be set when using the Go package + if config.InProcessServer == nil { + return fmt.Errorf("InProcessServer is required for InProcess connection type in client '%s' (Go package only)", config.Name) + } + default: + return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) + } + return nil +} + +func validateMCPClientName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("name is required for MCP client") + } + for _, r := range name { + if r > 127 { // non-ASCII + return fmt.Errorf("name must contain only ASCII characters") + } + } + if strings.Contains(name, "-") { + return fmt.Errorf("name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("name cannot start with a number") + } + return nil +} + +// parseToolName parses the tool name to be JavaScript-compatible. +// It converts spaces and hyphens to underscores, removes invalid characters, and ensures +// the name starts with a valid JavaScript identifier character. +func parseToolName(toolName string) string { + if toolName == "" { + return "" + } + + var result strings.Builder + runes := []rune(toolName) + + // Process first character - must be letter, underscore, or dollar sign + if len(runes) > 0 { + first := runes[0] + if unicode.IsLetter(first) || first == '_' || first == '$' { + result.WriteRune(unicode.ToLower(first)) + } else { + // If first char is invalid, prefix with underscore + result.WriteRune('_') + if unicode.IsDigit(first) { + result.WriteRune(first) + } + } + } + + // Process remaining characters + for i := 1; i < len(runes); i++ { + r := runes[i] + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' { + result.WriteRune(unicode.ToLower(r)) + } else if unicode.IsSpace(r) || r == '-' { + // Replace spaces and hyphens with single underscore + // Avoid consecutive underscores + if result.Len() > 0 && result.String()[result.Len()-1] != '_' { + result.WriteRune('_') + } + } + // Skip other invalid characters + } + + parsed := result.String() + + // Remove trailing underscores + parsed = strings.TrimRight(parsed, "_") + + // Ensure we have at least one character + // Should never happen, but just in case + if parsed == "" { + return "tool" + } + + return parsed +} + +// extractToolCallsFromCode extracts tool calls from TypeScript code +// Tool calls are in the format: serverName.toolName(...) or await serverName.toolName(...) +func extractToolCallsFromCode(code string) ([]toolCallInfo, error) { + toolCalls := []toolCallInfo{} + + // Regex pattern to match tool calls: + // - Optional "await" keyword + // - Server name (identifier) + // - Dot + // - Tool name (identifier) + // - Opening parenthesis + // This pattern matches: await serverName.toolName( or serverName.toolName( + toolCallPattern := regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`) + + // Find all matches + matches := toolCallPattern.FindAllStringSubmatch(code, -1) + for _, match := range matches { + if len(match) >= 3 { + serverName := match[1] + toolName := match[2] + toolCalls = append(toolCalls, toolCallInfo{ + serverName: serverName, + toolName: toolName, + }) + } + } + + return toolCalls, nil +} + +// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map +func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool { + // Check if the server name is in the list of all client names + if !slices.Contains(allClientNames, serverName) { + // It can be a built-in JavaScript/TypeScript object, if not then downstream execution will fail with a runtime error. + return true + } + + // Get allowed tools for this server + allowedTools, exists := allowedAutoExecutionTools[serverName] + if !exists { + // Server not in allowed list, return false to prevent downstream execution. + return false + } + + // Check if wildcard "*" is present (all tools allowed) + if slices.Contains(allowedTools, "*") { + return true + } + + // Check if specific tool is in the allowed list + if slices.Contains(allowedTools, toolName) { + return true + } + + return false // Tool not in allowed list +} + +// hasToolCalls checks if a chat response contains tool calls that need to be executed +func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool { + if response == nil || len(response.Choices) == 0 { + return false + } + + choice := response.Choices[0] + + // If finish_reason is "stop", this indicates non-auto-executable tools that require user approval. + // Don't return true even if tool calls are present, as the agent loop should not process them. + if choice.FinishReason != nil && *choice.FinishReason == "stop" { + return false + } + + // Check finish reason + if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { + return true + } + + // Check if message has tool calls + if choice.ChatNonStreamResponseChoice != nil && + choice.ChatNonStreamResponseChoice.Message != nil && + choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil && + len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 { + return true + } + + return false +} + +func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool { + if response == nil || len(response.Output) == 0 { + return false + } + + // Check if any output message is a tool call + for _, output := range response.Output { + if output.Type == nil { + continue + } + + // Check for tool call types + switch *output.Type { + case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall: + // Verify that ResponsesToolMessage is actually set + if output.ResponsesToolMessage != nil { + return true + } + } + } + + return false +} + +// stripClientPrefix removes the client name prefix from a tool name. +// Tool names are stored with format "{clientName}_{toolName}", but when calling +// the MCP server, we need the original tool name without the prefix. +// +// Parameters: +// - prefixedToolName: Tool name with client prefix (e.g., "calculator_add") +// - clientName: Client name to strip (e.g., "calculator") +// +// Returns: +// - string: Original tool name without prefix (e.g., "add") +func stripClientPrefix(prefixedToolName, clientName string) string { + prefix := clientName + "_" + if strings.HasPrefix(prefixedToolName, prefix) { + return strings.TrimPrefix(prefixedToolName, prefix) + } + // If prefix doesn't match, return as-is (shouldn't happen, but be safe) + return prefixedToolName +} diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index 9a72779b8..e6be7657a 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -49,6 +49,7 @@ func NewNebiusProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* logger: logger, client: client, networkConfig: config.NetworkConfig, + sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, }, nil } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 86947f12a..cc9be26fe 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -2019,6 +2019,8 @@ func HandleOpenAITranscriptionRequest( return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } + //TODO: add HandleProviderResponse here + // Parse raw response for RawResponse field var rawResponse interface{} if sendBackRawResponse { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index d38283ac8..6646be966 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -124,11 +124,11 @@ const ( BifrostContextKeyRequestID BifrostContextKey = "request-id" // string BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc. - BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost) + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string @@ -136,7 +136,8 @@ const ( BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) - BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost) + BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) BifrostContextKeyStructuredOutputToolName BifrostContextKey = "bifrost-structured-output-tool-name" // string (to store the name of the structured output tool (set by bifrost)) BifrostContextKeyUserAgent BifrostContextKey = "bifrost-user-agent" // string (set by bifrost) ) @@ -486,7 +487,7 @@ type BifrostError struct { Error *ErrorField `json:"error"` AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) StreamControl *StreamControl `json:"-"` // Optional: Controls stream behavior - ExtraFields BifrostErrorExtraFields `json:"extra_fields,omitempty"` + ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } // StreamControl represents stream control options. diff --git a/core/schemas/context.go b/core/schemas/context.go index 6ff56eb1b..bcfae8108 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -171,6 +171,9 @@ func (bc *BifrostContext) SetValue(key, value any) { } bc.valuesMu.Lock() defer bc.valuesMu.Unlock() + if bc.userValues == nil { + bc.userValues = make(map[any]any) + } bc.userValues[key] = value } diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index e26409e12..f54998cc5 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -1,32 +1,69 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -// MCPServerInstance represents an MCP server instance for InProcess connections. -// This should be a *github.com/mark3labs/mcp-go/server.MCPServer instance. -// We use interface{} to avoid creating a dependency on the mcp-go package in schemas. -type MCPServerInstance interface{} +import ( + "context" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/server" +) // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { - ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration + + // Function to fetch a new request ID for each tool call result message in agent mode, + // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. + // This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode. + // If not provider, same request ID is used for all tool call result messages without any overrides. + FetchNewRequestIDFunc func(ctx context.Context) string `json:"-"` +} + +type MCPToolManagerConfig struct { + ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` + MaxAgentDepth int `json:"max_agent_depth"` + CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" } +const ( + DefaultMaxAgentDepth = 10 + DefaultToolExecutionTimeout = 30 * time.Second +) + +// CodeModeBindingLevel defines how tools are exposed in the VFS for code execution +type CodeModeBindingLevel string + +const ( + CodeModeBindingLevelServer CodeModeBindingLevel = "server" + CodeModeBindingLevelTool CodeModeBindingLevel = "tool" +) + // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { ID string `json:"id"` // Client ID Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) Headers map[string]string `json:"headers,omitempty"` // Headers to send with the request - InProcessServer MCPServerInstance `json:"-"` // MCP server instance for in-process connections (Go package only) + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools + ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + // ToolsToAutoExecute semantics: + // - ["*"] => all tools are auto-executed + // - [] => no tools are auto-executed (deny-by-default) + // - nil/omitted => treated as [] (no tools) + // - ["tool1", "tool2"] => auto-execute only the specified tools + // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. } // MCPConnectionType defines the communication protocol for MCP connections @@ -54,9 +91,28 @@ const ( MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used ) +// MCPClientState represents a connected MCP client with its configuration and tools. +// It is used internally by the MCP manager to track the state of a connected MCP client. +type MCPClientState struct { + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig MCPClientConfig // Tool filtering settings + ToolMap map[string]ChatTool // Available tools mapped by name + ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) + State MCPConnectionState // Connection state (connected, disconnected, error) +} + +// MCPClientConnectionInfo stores metadata about how a client is connected. +type MCPClientConnectionInfo struct { + Type MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess) + ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections) + StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections) +} + // MCPClient represents a connected MCP client with its configuration and tools, // and connection information, after it has been initialized. -// It is returned by GetMCPClients() method. +// It is returned by GetMCPClients() method in bifrost. type MCPClient struct { Config MCPClientConfig `json:"config"` // Tool filtering settings Tools []ChatToolFunction `json:"tools"` // Available tools diff --git a/core/schemas/mux.go b/core/schemas/mux.go index ccd058b17..83f36c701 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -113,6 +113,99 @@ func (rt *ResponsesTool) ToChatTool() *ChatTool { return ct } +// ToChatAssistantMessageToolCall converts a ResponsesToolMessage to ChatAssistantMessageToolCall format. +// This is useful for executing Responses API tool calls using the Chat API tool executor. +// +// Returns: +// - *ChatAssistantMessageToolCall: The converted tool call in Chat API format +// +// Example: +// +// responsesToolMsg := &ResponsesToolMessage{ +// CallID: Ptr("call-123"), +// Name: Ptr("calculate"), +// Arguments: Ptr("{\"x\": 10, \"y\": 20}"), +// } +// chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() +func (rtm *ResponsesToolMessage) ToChatAssistantMessageToolCall() *ChatAssistantMessageToolCall { + if rtm == nil { + return nil + } + + toolCall := &ChatAssistantMessageToolCall{ + ID: rtm.CallID, + Type: Ptr("function"), + Function: ChatAssistantMessageToolCallFunction{ + Name: rtm.Name, + Arguments: "{}", // Default to empty JSON object for valid JSON unmarshaling + }, + } + + // Extract arguments string + if rtm.Arguments != nil { + toolCall.Function.Arguments = *rtm.Arguments + } + + return toolCall +} + +// ToResponsesToolMessage converts a ChatToolMessage (tool execution result) to ResponsesToolMessage format. +// This creates a function_call_output message suitable for the Responses API. +// +// Returns: +// - *ResponsesMessage: A ResponsesMessage with type=function_call_output containing the tool result +// +// Example: +// +// chatToolMsg := &ChatMessage{ +// Role: ChatMessageRoleTool, +// ChatToolMessage: &ChatToolMessage{ +// ToolCallID: Ptr("call-123"), +// }, +// Content: &ChatMessageContent{ +// ContentStr: Ptr("Result: 30"), +// }, +// } +// responsesMsg := chatToolMsg.ToResponsesToolMessage() +func (cm *ChatMessage) ToResponsesToolMessage() *ResponsesMessage { + if cm == nil || cm.ChatToolMessage == nil { + return nil + } + + msgType := ResponsesMessageTypeFunctionCallOutput + + respMsg := &ResponsesMessage{ + Type: &msgType, + ResponsesToolMessage: &ResponsesToolMessage{ + CallID: cm.ChatToolMessage.ToolCallID, + }, + } + + // Extract output from content + if cm.Content != nil { + if cm.Content.ContentStr != nil { + output := *cm.Content.ContentStr + respMsg.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: &output, + } + } else if len(cm.Content.ContentBlocks) > 0 { + // For structured content blocks, convert to ResponsesMessageContentBlock + respBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) + for i, block := range cm.Content.ContentBlocks { + respBlocks[i] = ResponsesMessageContentBlock{ + Type: ResponsesMessageContentBlockType(block.Type), + Text: block.Text, + } + } + respMsg.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ + ResponsesFunctionToolCallOutputBlocks: respBlocks, + } + } + } + + return respMsg +} + // ============================================================================= // TOOL CHOICE CONVERSION METHODS // ============================================================================= @@ -324,14 +417,17 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { role = ResponsesInputMessageRoleSystem case ChatMessageRoleTool: messageType = ResponsesMessageTypeFunctionCallOutput - role = ResponsesInputMessageRoleUser // Tool messages are typically user role in responses + role = "" // tool call output messages don't include a role field case ChatMessageRoleDeveloper: role = ResponsesInputMessageRoleDeveloper } rm := ResponsesMessage{ Type: &messageType, - Role: &role, + } + + if role != "" { + rm.Role = &role } // Handle refusal content specifically - use content blocks with ResponsesOutputMessageContentRefusal @@ -347,7 +443,10 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } } else if cm.Content != nil && cm.Content.ContentStr != nil { // Convert regular string content (if input message then ContentStr, else ContentBlocks) - if cm.Role == ChatMessageRoleAssistant { + // Skip setting content for function_call_output - content should only be in output field + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Don't set content for function_call_output - it will be set in ResponsesToolMessage.Output + } else if cm.Role == ChatMessageRoleAssistant { rm.Content = &ResponsesMessageContent{ ContentBlocks: []ResponsesMessageContentBlock{ {Type: ResponsesOutputMessageContentTypeText, Text: cm.Content.ContentStr}, @@ -360,57 +459,62 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } } else if cm.Content != nil && cm.Content.ContentBlocks != nil { // Convert content blocks - responseBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) - for i, block := range cm.Content.ContentBlocks { - blockType := ResponsesMessageContentBlockType(block.Type) - - switch block.Type { - case ChatContentBlockTypeText: - if cm.Role == ChatMessageRoleAssistant { - blockType = ResponsesOutputMessageContentTypeText - } else { - blockType = ResponsesInputMessageContentBlockTypeText + // Skip setting content blocks for function_call_output + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Don't set content for function_call_output - it will be set in ResponsesToolMessage.Output + } else { + responseBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) + for i, block := range cm.Content.ContentBlocks { + blockType := ResponsesMessageContentBlockType(block.Type) + + switch block.Type { + case ChatContentBlockTypeText: + if cm.Role == ChatMessageRoleAssistant { + blockType = ResponsesOutputMessageContentTypeText + } else { + blockType = ResponsesInputMessageContentBlockTypeText + } + case ChatContentBlockTypeImage: + blockType = ResponsesInputMessageContentBlockTypeImage + case ChatContentBlockTypeFile: + blockType = ResponsesInputMessageContentBlockTypeFile + case ChatContentBlockTypeInputAudio: + blockType = ResponsesInputMessageContentBlockTypeAudio } - case ChatContentBlockTypeImage: - blockType = ResponsesInputMessageContentBlockTypeImage - case ChatContentBlockTypeFile: - blockType = ResponsesInputMessageContentBlockTypeFile - case ChatContentBlockTypeInputAudio: - blockType = ResponsesInputMessageContentBlockTypeAudio - } - responseBlocks[i] = ResponsesMessageContentBlock{ - Type: blockType, - Text: block.Text, - } - - // Convert specific block types - if block.ImageURLStruct != nil { - responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ - ImageURL: &block.ImageURLStruct.URL, - Detail: block.ImageURLStruct.Detail, + responseBlocks[i] = ResponsesMessageContentBlock{ + Type: blockType, + Text: block.Text, } - } - if block.File != nil { - responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ - FileData: block.File.FileData, - Filename: block.File.Filename, + + // Convert specific block types + if block.ImageURLStruct != nil { + responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ + ImageURL: &block.ImageURLStruct.URL, + Detail: block.ImageURLStruct.Detail, + } } - responseBlocks[i].FileID = block.File.FileID - } - if block.InputAudio != nil { - format := "" - if block.InputAudio.Format != nil { - format = *block.InputAudio.Format + if block.File != nil { + responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ + FileData: block.File.FileData, + Filename: block.File.Filename, + } + responseBlocks[i].FileID = block.File.FileID } - responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ - Data: block.InputAudio.Data, - Format: format, + if block.InputAudio != nil { + format := "" + if block.InputAudio.Format != nil { + format = *block.InputAudio.Format + } + responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ + Data: block.InputAudio.Data, + Format: format, + } } } - } - rm.Content = &ResponsesMessageContent{ - ContentBlocks: responseBlocks, + rm.Content = &ResponsesMessageContent{ + ContentBlocks: responseBlocks, + } } } @@ -422,9 +526,18 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } // If tool output content exists, add it to function_call_output - if rm.Content != nil && rm.Content.ContentStr != nil && *rm.Content.ContentStr != "" { + // For function_call_output, get content from cm.Content since rm.Content is not set + var outputContent *string + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Get content directly from ChatMessage for function_call_output + if cm.Content != nil && cm.Content.ContentStr != nil && *cm.Content.ContentStr != "" { + outputContent = cm.Content.ContentStr + } + } + + if outputContent != nil { rm.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ - ResponsesToolCallOutputStr: rm.Content.ContentStr, + ResponsesToolCallOutputStr: outputContent, } } } diff --git a/core/schemas/utils.go b/core/schemas/utils.go index 66ec247f9..8b58bf0e8 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -727,6 +727,92 @@ func deepCopyChatContentBlock(original ChatContentBlock) ChatContentBlock { return copy } +// DeepCopyChatTool creates a deep copy of a ChatTool +// to prevent shared data mutation between different plugin accumulators +func DeepCopyChatTool(original ChatTool) ChatTool { + copyTool := ChatTool{ + Type: original.Type, + } + + // Deep copy Function if present + if original.Function != nil { + copyTool.Function = &ChatToolFunction{ + Name: original.Function.Name, + } + + if original.Function.Description != nil { + copyDescription := *original.Function.Description + copyTool.Function.Description = ©Description + } + + if original.Function.Parameters != nil { + copyParams := &ToolFunctionParameters{ + Type: original.Function.Parameters.Type, + } + + if original.Function.Parameters.Description != nil { + copyParamDesc := *original.Function.Parameters.Description + copyParams.Description = ©ParamDesc + } + + if original.Function.Parameters.Required != nil { + copyParams.Required = make([]string, len(original.Function.Parameters.Required)) + copy(copyParams.Required, original.Function.Parameters.Required) + } + + if original.Function.Parameters.Properties != nil { + // Deep copy the map + copyProps := make(map[string]interface{}, len(*original.Function.Parameters.Properties)) + for k, v := range *original.Function.Parameters.Properties { + copyProps[k] = DeepCopy(v) + } + orderedProps := OrderedMap(copyProps) + copyParams.Properties = &orderedProps + } + + if original.Function.Parameters.Enum != nil { + copyParams.Enum = make([]string, len(original.Function.Parameters.Enum)) + copy(copyParams.Enum, original.Function.Parameters.Enum) + } + + if original.Function.Parameters.AdditionalProperties != nil { + copyAdditionalProps := *original.Function.Parameters.AdditionalProperties + copyParams.AdditionalProperties = ©AdditionalProps + } + + copyTool.Function.Parameters = copyParams + } + + if original.Function.Strict != nil { + copyStrict := *original.Function.Strict + copyTool.Function.Strict = ©Strict + } + } + + // Deep copy Custom if present + if original.Custom != nil { + copyTool.Custom = &ChatToolCustom{} + + if original.Custom.Format != nil { + copyFormat := &ChatToolCustomFormat{ + Type: original.Custom.Format.Type, + } + + if original.Custom.Format.Grammar != nil { + copyGrammar := &ChatToolCustomGrammarFormat{ + Definition: original.Custom.Format.Grammar.Definition, + Syntax: original.Custom.Format.Grammar.Syntax, + } + copyFormat.Grammar = copyGrammar + } + + copyTool.Custom.Format = copyFormat + } + } + + return copyTool +} + // DeepCopyResponsesMessage creates a deep copy of a ResponsesMessage // to prevent shared data mutation between different plugin accumulators func DeepCopyResponsesMessage(original ResponsesMessage) ResponsesMessage { diff --git a/docs/changelogs/v1.3.47.mdx b/docs/changelogs/v1.3.47.mdx index ef79a5c49..208138514 100644 --- a/docs/changelogs/v1.3.47.mdx +++ b/docs/changelogs/v1.3.47.mdx @@ -17,28 +17,28 @@ description: "v1.3.47 changelog - 2025-12-12" -feat: support for raw response accumulation for streaming -feat: support for raw request logging and sending back in response -feat: added support for reasoning in chat completions -feat: enhanced reasoning support in responses api -enhancement: improved internal inter provider conversions for integrations -feat: switched to gemini native api +- feat: support for raw response accumulation for streaming +- feat: support for raw request logging and sending back in response +- feat: added support for reasoning in chat completions +- feat: enhanced reasoning support in responses api +- enhancement: improved internal inter provider conversions for integrations +- feat: switched to gemini native api -feat: send back raw request in extra fields -feat: added support for reasoning in chat completions -feat: enhanced reasoning support in responses api -enhancement: improved internal inter provider conversions for integrations -feat: switched to gemini native api -feat: fallback to supported request type for custom models used in integration +- feat: send back raw request in extra fields +- feat: added support for reasoning in chat completions +- feat: enhanced reasoning support in responses api +- enhancement: improved internal inter provider conversions for integrations +- feat: switched to gemini native api +- feat: fallback to supported request type for custom models used in integration -feat: support raw response accumulation in stream accumulator -feat: support raw request configuration and logging -feat: added support for reasoning accumulation in stream accumulator -chore: updating core to 1.2.37 and framework to 1.1.47 +- feat: support raw response accumulation in stream accumulator +- feat: support raw request configuration and logging +- feat: added support for reasoning accumulation in stream accumulator +- chore: updating core to 1.2.37 and framework to 1.1.47 diff --git a/docs/features/governance/virtual-keys.mdx b/docs/features/governance/virtual-keys.mdx index 1ebf673a0..89c64670a 100644 --- a/docs/features/governance/virtual-keys.mdx +++ b/docs/features/governance/virtual-keys.mdx @@ -565,7 +565,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ { "error": { "type": "virtual_key_required", - "message": "x-bf-vk header is missing" + "message": "virtual key is missing in headers" } } ``` @@ -615,7 +615,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ { "error": { "type": "budget_exceeded", - "message": "Budget check failed: VK budget exceeded: 105.50 > 100.00 dollars" + "message": "Budget exceeded: VK budget exceeded: 105.50 > 100.00 dollars" } } ``` diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index f38d3107a..047676d36 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "sort" + "strconv" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" @@ -46,6 +47,9 @@ type ClientConfig struct { AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution + MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds + MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } @@ -97,6 +101,24 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("enableLiteLLMFallbacks:false")) } + if c.MCPAgentDepth > 0 { + hash.Write([]byte("mcpAgentDepth:" + strconv.Itoa(c.MCPAgentDepth))) + } else { + hash.Write([]byte("mcpAgentDepth:0")) + } + + if c.MCPToolExecutionTimeout > 0 { + hash.Write([]byte("mcpToolExecutionTimeout:" + strconv.Itoa(c.MCPToolExecutionTimeout))) + } else { + hash.Write([]byte("mcpToolExecutionTimeout:0")) + } + + if c.MCPCodeModeBindingLevel != "" { + hash.Write([]byte("mcpCodeModeBindingLevel:" + c.MCPCodeModeBindingLevel)) + } else { + hash.Write([]byte("mcpCodeModeBindingLevel:server")) + } + // Hash integer fields data, err := sonic.Marshal(c.InitialPoolSize) if err != nil { diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 93975ff16..4ac808636 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -3,7 +3,10 @@ package configstore import ( "context" "fmt" + "log" "strconv" + "strings" + "unicode" "github.com/google/uuid" "github.com/maximhq/bifrost/core/schemas" @@ -77,6 +80,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationMissingProviderColumnInKeyTable(ctx, db); err != nil { return err } + if err := migrationAddToolsToAutoExecuteJSONColumn(ctx, db); err != nil { + return err + } + if err := migrationAddIsCodeModeClientColumn(ctx, db); err != nil { + return err + } if err := migrationAddLogRetentionDaysColumn(ctx, db); err != nil { return err } @@ -86,6 +95,15 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddBatchAndCachePricingColumns(ctx, db); err != nil { return err } + if err := migrationAddMCPAgentDepthAndMCPToolExecutionTimeoutColumns(ctx, db); err != nil { + return err + } + if err := migrationAddMCPCodeModeBindingLevelColumn(ctx, db); err != nil { + return err + } + if err := migrationNormalizeMCPClientNames(ctx, db); err != nil { + return err + } if err := migrationMoveKeysToProviderConfig(ctx, db); err != nil { return err } @@ -1071,6 +1089,74 @@ func migrationMissingProviderColumnInKeyTable(ctx context.Context, db *gorm.DB) return nil } +// migrationAddToolsToAutoExecuteJSONColumn adds the tools_to_auto_execute_json column to the mcp_client table +func migrationAddToolsToAutoExecuteJSONColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_tools_to_auto_execute_json_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json"); err != nil { + return err + } + // Initialize existing rows with empty array + if err := tx.Exec("UPDATE config_mcp_clients SET tools_to_auto_execute_json = '[]' WHERE tools_to_auto_execute_json IS NULL OR tools_to_auto_execute_json = ''").Error; err != nil { + return fmt.Errorf("failed to initialize tools_to_auto_execute_json: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// migrationAddIsCodeModeClientColumn adds the is_code_mode_client column to the config_mcp_clients table +func migrationAddIsCodeModeClientColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_is_code_mode_client_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "is_code_mode_client") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "is_code_mode_client"); err != nil { + return err + } + // Initialize existing rows with false (default value) + if err := tx.Exec("UPDATE config_mcp_clients SET is_code_mode_client = false WHERE is_code_mode_client IS NULL").Error; err != nil { + return fmt.Errorf("failed to initialize is_code_mode_client: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableMCPClient{}, "is_code_mode_client"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + // migrationAddLogRetentionDaysColumn adds the log_retention_days column to the client config table func migrationAddLogRetentionDaysColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -1194,6 +1280,207 @@ func migrationAddBatchAndCachePricingColumns(ctx context.Context, db *gorm.DB) e return m.Migrate() } +func migrationAddMCPAgentDepthAndMCPToolExecutionTimeoutColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_agent_depth_and_mcp_tool_execution_timeout_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_agent_depth") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_agent_depth"); err != nil { + return err + } + } + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_agent_depth"); err != nil { + return err + } + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// migrationAddMCPCodeModeBindingLevelColumn adds the mcp_code_mode_binding_level column to the client config table. +// This column stores the code mode binding level preference (server or tool). +func migrationAddMCPCodeModeBindingLevelColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_code_mode_binding_level_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if !migratorInstance.HasColumn(&tables.TableClientConfig{}, "mcp_code_mode_binding_level") { + if err := migratorInstance.AddColumn(&tables.TableClientConfig{}, "mcp_code_mode_binding_level"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if err := migratorInstance.DropColumn(&tables.TableClientConfig{}, "mcp_code_mode_binding_level"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// normalizeMCPClientName normalizes an MCP client name by: +// 1. Replacing hyphens and spaces with underscores +// 2. Removing leading digits +// 3. Using a default name if the result is empty +func normalizeMCPClientName(name string) string { + // Replace hyphens and spaces with underscores + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + + // Remove leading digits + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return unicode.IsDigit(r) + }) + + // If name becomes empty after normalization, use a default name + if normalized == "" { + normalized = "mcp_client" + } + + return normalized +} + +// migrationNormalizeMCPClientNames normalizes MCP client names by: +// 1. Replacing hyphens and spaces with underscores +// 2. Removing leading digits +// 3. Adding number suffix if name already exists +func migrationNormalizeMCPClientNames(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "normalize_mcp_client_names", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + + // Fetch all MCP clients + var mcpClients []tables.TableMCPClient + if err := tx.Find(&mcpClients).Error; err != nil { + return fmt.Errorf("failed to fetch MCP clients: %w", err) + } + + // Track assigned names in memory to avoid transaction visibility issues + // and ensure we see all updates made during this migration + assignedNames := make(map[string]bool) + + // Helper function to find a unique name + findUniqueName := func(baseName string, originalName string, excludeID uint, tx *gorm.DB, assignedNames map[string]bool) (string, error) { + // First check if base name is already assigned in this migration + if !assignedNames[baseName] { + // Also check database for existing names (excluding current client) + var existing tables.TableMCPClient + err := tx.Where("name = ? AND id != ?", baseName, excludeID).First(&existing).Error + if err == gorm.ErrRecordNotFound { + // Name is available + assignedNames[baseName] = true + // Log normalization even when no collision + if originalName != baseName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, baseName) + } + return baseName, nil + } else if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + } + + // Name exists (either assigned in this migration or in database), try with number suffix starting from 2 + // (base name is conceptually "1", so collisions start from "2") + suffix := 2 + const maxSuffix = 1000 // Safety limit to prevent infinite loops + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + + // Check both in-memory map and database + if !assignedNames[candidateName] { + var existing tables.TableMCPClient + err := tx.Where("name = ? AND id != ?", candidateName, excludeID).First(&existing).Error + if err == gorm.ErrRecordNotFound { + // Found available name - log the transformation + assignedNames[candidateName] = true + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } else if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + } + suffix++ + } + } + + // Process each client + for _, client := range mcpClients { + originalName := client.Name + needsUpdate := false + + // Check if name needs normalization + if strings.Contains(originalName, "-") || strings.Contains(originalName, " ") { + needsUpdate = true + } else if len(originalName) > 0 && unicode.IsDigit(rune(originalName[0])) { + needsUpdate = true + } + + if needsUpdate { + // Normalize the name + normalizedName := normalizeMCPClientName(originalName) + + // Find a unique name (pass assignedNames map to track names in this migration) + uniqueName, err := findUniqueName(normalizedName, originalName, client.ID, tx, assignedNames) + if err != nil { + return fmt.Errorf("failed to find unique name for client %d (original: %s): %w", client.ID, originalName, err) + } + + // Update the client name + if err := tx.Model(&client).Update("name", uniqueName).Error; err != nil { + return fmt.Errorf("failed to update MCP client %d name from %s to %s: %w", client.ID, originalName, uniqueName, err) + } + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + // Rollback is not possible as we don't store the original names + // This migration is one-way + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running MCP client name normalization migration: %s", err.Error()) + } + return nil +} + // migrationMoveKeysToProviderConfig migrates keys from virtual key level to provider config level func migrationMoveKeysToProviderConfig(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go new file mode 100644 index 000000000..cada594b3 --- /dev/null +++ b/framework/configstore/migrations_test.go @@ -0,0 +1,539 @@ +package configstore + +import ( + "bytes" + "context" + "fmt" + "log" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// setupTestDB creates an in-memory SQLite database for testing +func setupTestDB(t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err, "Failed to create test database") + + // Create the MCP clients table + err = db.AutoMigrate(&tables.TableMCPClient{}) + require.NoError(t, err, "Failed to migrate test database") + + return db +} + +// captureLogOutput captures log output during a function execution +func captureLogOutput(fn func()) string { + var buf bytes.Buffer + log.SetOutput(&buf) + defer log.SetOutput(os.Stderr) + + fn() + return buf.String() +} + +func TestNormalizeName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "hyphen to underscore", + input: "my-tool", + expected: "my_tool", + }, + { + name: "space to underscore", + input: "my tool", + expected: "my_tool", + }, + { + name: "multiple hyphens", + input: "my-super-tool", + expected: "my_super_tool", + }, + { + name: "multiple spaces", + input: "my super tool", + expected: "my_super_tool", + }, + { + name: "leading digits removed", + input: "123tool", + expected: "tool", + }, + { + name: "leading digits with hyphen", + input: "123my-tool", + expected: "my_tool", + }, + { + name: "empty after normalization", + input: "123", + expected: "mcp_client", + }, + { + name: "no change needed", + input: "my_tool", + expected: "my_tool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + normalized := normalizeMCPClientName(tt.input) + assert.Equal(t, tt.expected, normalized, "normalizeMCPClientName should produce expected output") + }) + } +} + +func TestFindUniqueName_NoCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create a test client with a unique name + client := &tables.TableMCPClient{ + Name: "existing_client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + + // Test findUniqueName with a different base name (no collision) + logOutput := captureLogOutput(func() { + uniqueName, err := findUniqueNameForTest("new_client", "new_client", 999, db.WithContext(ctx)) + require.NoError(t, err) + assert.Equal(t, "new_client", uniqueName, "Should return base name when no collision") + }) + + // Should not log anything when there's no collision + assert.Empty(t, logOutput, "Should not log when name is available without suffix") +} + +func TestFindUniqueName_WithCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing clients that will cause collisions + // First client with base name + client1 := &tables.TableMCPClient{ + Name: "my_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client1).Error + require.NoError(t, err) + + // Second client with first suffix + client2 := &tables.TableMCPClient{ + Name: "my_tool1", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client2).Error + require.NoError(t, err) + + // Test findUniqueName with collision - should find "my_tool2" + // excludeID is set to a non-existent ID (999) so all existing clients are considered + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("my_tool", "my-tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "my_tool2", uniqueName, "Should return name with suffix when collision occurs") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the transformation") +} + +func TestFindUniqueName_MultipleCollisions(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing clients that will cause multiple collisions + client1 := &tables.TableMCPClient{ + Name: "test_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client1).Error + require.NoError(t, err) + + client2 := &tables.TableMCPClient{ + Name: "test_tool1", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client2).Error + require.NoError(t, err) + + client3 := &tables.TableMCPClient{ + Name: "test_tool2", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client3).Error + require.NoError(t, err) + + // Test findUniqueName with multiple collisions - should find "test_tool3" + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("test_tool", "test tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "test_tool3", uniqueName, "Should return name with correct suffix after multiple collisions") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'test tool' -> 'test_tool3'", "Should log the transformation") +} + +func TestFindUniqueName_NormalizationAndCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing client with normalized name + client := &tables.TableMCPClient{ + Name: "my_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + + // Test that "my-tool" normalizes to "my_tool" and then collides, requiring suffix + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("my_tool", "my-tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "my_tool2", uniqueName, "Should handle normalization and collision") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the full transformation") +} + +func TestFindUniqueName_MultipleNormalizationsToSameBase(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Test case: 3 entries that normalize to the same base name: + // "mcp client" -> "mcp_client" + // "mcp-client" -> "mcp_client" (collision, becomes "mcp_client2") + // "1mcp-client" -> "mcp_client" (collision, becomes "mcp_client3") + // Note: In the actual migration, names are processed sequentially and each checks + // against all previously created names. To simulate this, we need to create clients + // with the original names first, then normalize them in sequence. + + // Helper function to normalize (same logic as in migrations.go) + normalizeName := func(name string) string { + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return r >= '0' && r <= '9' + }) + if normalized == "" { + normalized = "mcp_client" + } + return normalized + } + + // Create three clients with original names (simulating pre-migration state) + clients := []*tables.TableMCPClient{ + { + Name: "mcp client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "mcp-client", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "1mcp-client", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + for _, client := range clients { + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + } + + // Now simulate the migration: process each client sequentially + // First: "mcp client" -> "mcp_client" (no collision) + client1 := clients[0] + normalizedName1 := normalizeName(client1.Name) + var uniqueName1 string + var err error + logOutput1 := captureLogOutput(func() { + uniqueName1, err = findUniqueNameForTest(normalizedName1, client1.Name, client1.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + assert.Equal(t, "mcp_client", uniqueName1, "First normalization should use base name") + assert.Empty(t, logOutput1, "Should not log when name is available without suffix") + + // Update first client + err = db.WithContext(ctx).Model(client1).Update("name", uniqueName1).Error + require.NoError(t, err) + + // Second: "mcp-client" -> "mcp_client" (collision with "mcp_client", becomes "mcp_client2") + // Note: We need to check that "mcp_client" exists (from client1), so it should skip to "mcp_client2" + client2 := clients[1] + normalizedName2 := normalizeName(client2.Name) + var uniqueName2 string + logOutput2 := captureLogOutput(func() { + uniqueName2, err = findUniqueNameForTest(normalizedName2, client2.Name, client2.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + // With the updated implementation, suffixes start from 2 when base name exists + // So "mcp-client" normalizes to "mcp_client" which collides, becomes "mcp_client2" + assert.Equal(t, "mcp_client2", uniqueName2, "Second normalization should get suffix 2 (skipping 1)") + assert.Contains(t, logOutput2, "MCP Client Name Normalized: 'mcp-client' -> 'mcp_client2'", "Should log the transformation") + + // Update second client + err = db.WithContext(ctx).Model(client2).Update("name", uniqueName2).Error + require.NoError(t, err) + + // Third: "1mcp-client" -> "mcp_client" (collision with "mcp_client" and "mcp_client2", becomes "mcp_client3") + client3 := clients[2] + normalizedName3 := normalizeName(client3.Name) + var uniqueName3 string + logOutput3 := captureLogOutput(func() { + uniqueName3, err = findUniqueNameForTest(normalizedName3, client3.Name, client3.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + // Third normalization finds "mcp_client" and "mcp_client2" exist, so becomes "mcp_client3" + assert.Equal(t, "mcp_client3", uniqueName3, "Third normalization should get suffix 3") + assert.Contains(t, logOutput3, "MCP Client Name Normalized: '1mcp-client' -> 'mcp_client3'", "Should log the transformation") + + // Update third client + err = db.WithContext(ctx).Model(client3).Update("name", uniqueName3).Error + require.NoError(t, err) + + // Final verification: all three should exist with correct names + var finalClients []tables.TableMCPClient + err = db.WithContext(ctx).Find(&finalClients).Error + require.NoError(t, err) + assert.Len(t, finalClients, 3, "Should have all 3 clients") + + names := make([]string, len(finalClients)) + for i, c := range finalClients { + names[i] = c.Name + } + assert.Contains(t, names, "mcp_client", "Should contain mcp_client") + assert.Contains(t, names, "mcp_client2", "Should contain mcp_client2") + assert.Contains(t, names, "mcp_client3", "Should contain mcp_client3") +} + +func TestFindUniqueName_MigrationScenarioWithInMemoryTracking(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // This test simulates the exact migration scenario where clients are processed in a loop + // and we need to track assigned names in memory to avoid transaction visibility issues + + // Create three clients with original names (simulating pre-migration state) + clients := []*tables.TableMCPClient{ + { + Name: "mcp client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "mcp-client", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "1mcp-client", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + for _, client := range clients { + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + } + + // Simulate the migration: process clients in a loop with in-memory tracking + assignedNames := make(map[string]bool) + normalizeName := func(name string) string { + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return r >= '0' && r <= '9' + }) + if normalized == "" { + normalized = "mcp_client" + } + return normalized + } + + var logOutputs []string + for _, client := range clients { + originalName := client.Name + needsUpdate := strings.Contains(originalName, "-") || strings.Contains(originalName, " ") || + (len(originalName) > 0 && originalName[0] >= '0' && originalName[0] <= '9') + + if needsUpdate { + normalizedName := normalizeName(originalName) + uniqueName, err := findUniqueNameForTestWithTracking(normalizedName, originalName, client.ID, db.WithContext(ctx), assignedNames) + require.NoError(t, err) + + // Capture log output + logOutput := captureLogOutput(func() { + // Log if name changed + if originalName != uniqueName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, uniqueName) + } + }) + if logOutput != "" { + logOutputs = append(logOutputs, logOutput) + } + + // Update client + err = db.WithContext(ctx).Model(client).Update("name", uniqueName).Error + require.NoError(t, err) + } + } + + // Verify all three clients have correct names + var finalClients []tables.TableMCPClient + err := db.WithContext(ctx).Find(&finalClients).Error + require.NoError(t, err) + assert.Len(t, finalClients, 3, "Should have all 3 clients") + + names := make([]string, len(finalClients)) + for i, c := range finalClients { + names[i] = c.Name + } + assert.Contains(t, names, "mcp_client", "Should contain mcp_client") + assert.Contains(t, names, "mcp_client2", "Should contain mcp_client2") + assert.Contains(t, names, "mcp_client3", "Should contain mcp_client3") + + // Verify logging: should log all three transformations + allLogs := strings.Join(logOutputs, "") + assert.Contains(t, allLogs, "MCP Client Name Normalized: 'mcp client' -> 'mcp_client'", "Should log first normalization") + assert.Contains(t, allLogs, "MCP Client Name Normalized: 'mcp-client' -> 'mcp_client2'", "Should log second normalization") + assert.Contains(t, allLogs, "MCP Client Name Normalized: '1mcp-client' -> 'mcp_client3'", "Should log third normalization") +} + +// findUniqueNameForTestWithTracking is a test helper that tracks assigned names in memory +func findUniqueNameForTestWithTracking(baseName string, originalName string, excludeID uint, tx *gorm.DB, assignedNames map[string]bool) (string, error) { + // First check if base name is already assigned in this migration + if !assignedNames[baseName] { + // Also check database for existing names (excluding current client) + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", baseName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Name is available + assignedNames[baseName] = true + // Log normalization even when no collision + if originalName != baseName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, baseName) + } + return baseName, nil + } + } + + // Name exists (either assigned in this migration or in database), try with number suffix starting from 2 + suffix := 2 + const maxSuffix = 1000 + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + + // Check both in-memory map and database + if !assignedNames[candidateName] { + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", candidateName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Found available name + assignedNames[candidateName] = true + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } + } + suffix++ + } +} + +// findUniqueNameForTest is a test helper that extracts the findUniqueName logic +// This mirrors the implementation in migrations.go for testing +func findUniqueNameForTest(baseName string, originalName string, excludeID uint, tx *gorm.DB) (string, error) { + // First, try the base name + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", baseName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Name is available + return baseName, nil + } + + // Name exists, try with number suffix starting from 2 + // (base name is conceptually "1", so collisions start from "2") + suffix := 2 + const maxSuffix = 1000 // Safety limit to prevent infinite loops + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", candidateName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Found available name - log the transformation + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } + suffix++ + } +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 472400c7f..e420cf328 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" @@ -41,6 +42,9 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowedOrigins: config.AllowedOrigins, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, EnableLiteLLMFallbacks: config.EnableLiteLLMFallbacks, + MCPAgentDepth: config.MCPAgentDepth, + MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, + MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, } // Delete existing client config and create new one in a transaction return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -158,7 +162,6 @@ func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tabl } return tx.Create(config).Error }) - } // GetFrameworkConfig retrieves the framework configuration from the database. @@ -195,6 +198,9 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er AllowedOrigins: dbConfig.AllowedOrigins, MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, EnableLiteLLMFallbacks: dbConfig.EnableLiteLLMFallbacks, + MCPAgentDepth: dbConfig.MCPAgentDepth, + MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, + MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, }, nil } @@ -570,7 +576,7 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo return err } - // Delete the provider (keys will be deleted due to CASCADE constraint) + // Delete the provider first (keys will be deleted due to CASCADE constraint) if err := txDB.WithContext(ctx).Delete(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound @@ -585,9 +591,6 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) { var dbProviders []tables.TableProvider if err := s.db.WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } if len(dbProviders) == 0 { @@ -665,7 +668,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo if processedARN, err := envutils.ProcessEnvValue(*bedrockConfig.ARN); err == nil { bedrockConfigCopy.ARN = &processedARN } - } + } bedrockConfig = &bedrockConfigCopy } @@ -734,17 +737,40 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, } clientConfigs[i] = schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: processedConnectionString, - StdioConfig: dbClient.StdioConfig, - ToolsToExecute: dbClient.ToolsToExecute, - Headers: processedHeaders, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: processedConnectionString, + StdioConfig: dbClient.StdioConfig, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: processedHeaders, + } + } + var clientConfig tables.TableClientConfig + if err := s.db.WithContext(ctx).First(&clientConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // Return MCP config with default ToolManagerConfig if no client config exists + // This will never happen, but just in case. + return &schemas.MCPConfig{ + ClientConfigs: clientConfigs, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig + MaxAgentDepth: 10, // default from TableClientConfig + }, + }, nil } + return nil, err + } + toolManagerConfig := schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second, + MaxAgentDepth: clientConfig.MCPAgentDepth, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), } return &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigs, + ToolManagerConfig: &toolManagerConfig, }, nil } @@ -770,17 +796,20 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // Substitute environment variables back to their original form - substituteMCPClientEnvVars(&clientConfigCopy, envKeys) + // For create operations, no existing headers to restore from + substituteMCPClientEnvVars(&clientConfigCopy, envKeys, nil) // Create new client dbClient := tables.TableMCPClient{ - ClientID: clientConfigCopy.ID, - Name: clientConfigCopy.Name, - ConnectionType: string(clientConfigCopy.ConnectionType), - ConnectionString: clientConfigCopy.ConnectionString, - StdioConfig: clientConfigCopy.StdioConfig, - ToolsToExecute: clientConfigCopy.ToolsToExecute, - Headers: clientConfigCopy.Headers, + ClientID: clientConfigCopy.ID, + Name: clientConfigCopy.Name, + IsCodeModeClient: clientConfigCopy.IsCodeModeClient, + ConnectionType: string(clientConfigCopy.ConnectionType), + ConnectionString: clientConfigCopy.ConnectionString, + StdioConfig: clientConfigCopy.StdioConfig, + ToolsToExecute: clientConfigCopy.ToolsToExecute, + ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, + Headers: clientConfigCopy.Headers, } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { @@ -809,17 +838,20 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c } // Substitute environment variables back to their original form - substituteMCPClientEnvVars(&clientConfigCopy, envKeys) + // Pass existing headers to restore redacted plain values + substituteMCPClientEnvVars(&clientConfigCopy, envKeys, existingClient.Headers) // Update existing client existingClient.Name = clientConfigCopy.Name - existingClient.ConnectionType = string(clientConfigCopy.ConnectionType) - existingClient.ConnectionString = clientConfigCopy.ConnectionString - existingClient.StdioConfig = clientConfigCopy.StdioConfig + existingClient.IsCodeModeClient = clientConfigCopy.IsCodeModeClient existingClient.ToolsToExecute = clientConfigCopy.ToolsToExecute + existingClient.ToolsToAutoExecute = clientConfigCopy.ToolsToAutoExecute existingClient.Headers = clientConfigCopy.Headers - if err := tx.WithContext(ctx).Updates(&existingClient).Error; err != nil { + // Use Select to explicitly include IsCodeModeClient even when it's false (zero value) + // GORM's Updates() skips zero values by default, so we need to explicitly select fields + // Using struct field names - GORM will convert them to column names automatically + if err := tx.WithContext(ctx).Select("name", "is_code_mode_client", "tools_to_execute_json", "tools_to_auto_execute_json", "headers_json", "updated_at").Updates(&existingClient).Error; err != nil { return s.parseGormError(err) } return nil @@ -928,9 +960,6 @@ func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logs func (s *RDBConfigStore) GetEnvKeys(ctx context.Context) (map[string][]EnvKeyInfo, error) { var dbEnvKeys []tables.TableEnvKey if err := s.db.WithContext(ctx).Find(&dbEnvKeys).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } envKeys := make(map[string][]EnvKeyInfo) @@ -1365,7 +1394,80 @@ func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ( // DeleteVirtualKey deletes a virtual key from the database. func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error + if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var virtualKey tables.TableVirtualKey + if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + + // Collect budget and rate limit IDs from provider configs before deletion + var providerConfigBudgetIDs []string + var providerConfigRateLimitIDs []string + for _, pc := range virtualKey.ProviderConfigs { + // Delete the keys join table entries + if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil { + return err + } + // Collect budget and rate limit IDs for deletion after provider config + if pc.BudgetID != nil { + providerConfigBudgetIDs = append(providerConfigBudgetIDs, *pc.BudgetID) + } + if pc.RateLimitID != nil { + providerConfigRateLimitIDs = append(providerConfigRateLimitIDs, *pc.RateLimitID) + } + } + + // Delete all provider configs associated with the virtual key first + if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil { + return err + } + // Now delete the collected budgets and rate limits + for _, budgetID := range providerConfigBudgetIDs { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", budgetID).Error; err != nil { + return err + } + } + for _, rateLimitID := range providerConfigRateLimitIDs { + if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", rateLimitID).Error; err != nil { + return err + } + } + // Delete all MCP configs associated with the virtual key + if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil { + return err + } + // Delete the budget associated with the virtual key + budgetID := virtualKey.BudgetID + rateLimitID := virtualKey.RateLimitID + // Delete the virtual key + if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + if budgetID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + // Delete the rate limit associated with the virtual key + if rateLimitID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { + return err + } + } + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + return nil } // GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database. @@ -1447,7 +1549,34 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id } else { txDB = s.db } - return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error + // First fetch the provider config to get budget and rate limit IDs + var providerConfig tables.TableVirtualKeyProviderConfig + if err := txDB.WithContext(ctx).First(&providerConfig, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Store the budget and rate limit IDs before deleting + budgetID := providerConfig.BudgetID + rateLimitID := providerConfig.RateLimitID + // Delete the provider config first + if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil { + return err + } + // Delete the budget if it exists + if budgetID != nil { + if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + // Delete the rate limit if it exists + if rateLimitID != nil { + if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { + return err + } + } + return nil } // GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database. @@ -1518,9 +1647,6 @@ func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tab } var teams []tables.TableTeam if err := query.Find(&teams).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } return teams, nil @@ -1568,16 +1694,47 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, // DeleteTeam deletes a team from the database. func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error + if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var team tables.TableTeam + if err := tx.WithContext(ctx).Preload("Budget").First(&team, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Set team_id to null for all virtual keys associated with the team + if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("team_id = ?", id).Update("team_id", nil).Error; err != nil { + return err + } + // Store the budget ID before deleting the team + budgetID := team.BudgetID + // Delete the team first + if err := tx.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Delete the team's budget if it exists + if budgetID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + return nil } // GetCustomers retrieves all customers from the database. func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) { var customers []tables.TableCustomer if err := s.db.WithContext(ctx).Preload("Teams").Preload("Budget").Find(&customers).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } return customers, nil @@ -1625,7 +1782,54 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta // DeleteCustomer deletes a customer from the database. func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error + if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var customer tables.TableCustomer + if err := tx.WithContext(ctx).Preload("Budget").First(&customer, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Set customer_id to null for all virtual keys associated with the customer + if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil { + return err + } + // Set customer_id to null for all teams associated with the customer + if err := tx.WithContext(ctx).Model(&tables.TableTeam{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil { + return err + } + // Store the budget ID before deleting the customer + budgetID := customer.BudgetID + // Delete the customer first + if err := tx.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Delete the customer's budget if it exists + if budgetID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + return nil +} + +// GetRateLimits retrieves all rate limits from the database. +func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { + var rateLimits []tables.TableRateLimit + if err := s.db.WithContext(ctx).Find(&rateLimits).Error; err != nil { + return nil, err + } + return rateLimits, nil } // GetRateLimit retrieves a specific rate limit from the database. @@ -1688,9 +1892,6 @@ func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tab func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) { var budgets []tables.TableBudget if err := s.db.WithContext(ctx).Find(&budgets).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } return budgets, nil @@ -1976,6 +2177,33 @@ func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gor return s.db.WithContext(ctx).Transaction(fn) } +// RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound +func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) { + var lastErr error + for attempt := range maxRetries { + result, err := fn(ctx) + if err == nil { + return result, nil + } + if !errors.Is(err, ErrNotFound) && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + + lastErr = err + + // Don't wait after the last attempt + if attempt < maxRetries-1 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + // Continue to next retry + } + } + } + return nil, lastErr +} + // doesTableExist checks if a table exists in the database. func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool { return s.db.WithContext(ctx).Migrator().HasTable(tableName) diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 04196c03f..ade2d3a78 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -4,6 +4,7 @@ package configstore import ( "context" "fmt" + "time" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" @@ -100,6 +101,7 @@ type ConfigStore interface { DeleteCustomer(ctx context.Context, id string) error // Rate limit CRUD + GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) GetRateLimit(ctx context.Context, id string) (*tables.TableRateLimit, error) CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error @@ -141,6 +143,9 @@ type ConfigStore interface { // Generic transaction manager ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error + // Not found retry wrapper + RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) + // DB returns the underlying database connection. DB() *gorm.DB diff --git a/framework/configstore/tables/budget.go b/framework/configstore/tables/budget.go index 1744363b2..543d0cb57 100644 --- a/framework/configstore/tables/budget.go +++ b/framework/configstore/tables/budget.go @@ -21,17 +21,20 @@ type TableBudget struct { CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + LastDBUsage float64 `gorm:"-" json:"-"` } // TableName sets the table name for each model func (TableBudget) TableName() string { return "governance_budgets" } // BeforeSave hook for Budget to validate reset duration format and max limit -func (b *TableBudget) BeforeSave(tx *gorm.DB) error { +func (b *TableBudget) BeforeSave(tx *gorm.DB) error { // Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y") if d, err := ParseDuration(b.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration) - }else if d <= 0 { + } else if d <= 0 { return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration) } // Validate that MaxLimit is not negative (budgets should be positive) @@ -41,3 +44,9 @@ func (b *TableBudget) BeforeSave(tx *gorm.DB) error { return nil } + +// AfterFind hook for Budget to set the LastDBUsage virtual field +func (b *TableBudget) AfterFind(tx *gorm.DB) error { + b.LastDBUsage = b.CurrentUsage + return nil +} \ No newline at end of file diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index 536d6a565..395277711 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -15,12 +15,16 @@ type TableClientConfig struct { AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"` EnableLogging bool `gorm:"" json:"enable_logging"` - DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged - LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) + DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged + LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) EnableGovernance bool `gorm:"" json:"enable_governance"` EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"` AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"` MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"` + MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"` + MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30) + MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool" + // LiteLLM fallback flag EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` @@ -33,7 +37,7 @@ type TableClientConfig struct { // Virtual fields for runtime use (not stored in DB) PrometheusLabels []string `gorm:"-" json:"prometheus_labels"` - AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` + AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` } // TableName sets the table name for each model diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index f5c2381a6..b991a545a 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -10,14 +10,16 @@ import ( // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. - ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` - Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` - ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType - ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` - StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig - ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash @@ -27,9 +29,10 @@ type TableMCPClient struct { UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Virtual fields for runtime use (not stored in DB) - StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` - ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` - Headers map[string]string `gorm:"-" json:"headers"` + StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` + ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` + ToolsToAutoExecute []string `gorm:"-" json:"tools_to_auto_execute"` + Headers map[string]string `gorm:"-" json:"headers"` } // TableName sets the table name for each model @@ -57,6 +60,16 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { c.ToolsToExecuteJSON = "[]" } + if c.ToolsToAutoExecute != nil { + data, err := json.Marshal(c.ToolsToAutoExecute) + if err != nil { + return err + } + c.ToolsToAutoExecuteJSON = string(data) + } else { + c.ToolsToAutoExecuteJSON = "[]" + } + if c.Headers != nil { data, err := json.Marshal(c.Headers) if err != nil { @@ -86,6 +99,12 @@ func (c *TableMCPClient) AfterFind(tx *gorm.DB) error { } } + if c.ToolsToAutoExecuteJSON != "" { + if err := json.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil { + return err + } + } + if c.HeadersJSON != "" { if err := json.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil { return err diff --git a/framework/configstore/tables/ratelimit.go b/framework/configstore/tables/ratelimit.go index 7147e7b89..1a46c690e 100644 --- a/framework/configstore/tables/ratelimit.go +++ b/framework/configstore/tables/ratelimit.go @@ -29,6 +29,10 @@ type TableRateLimit struct { CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + LastDBTokenUsage int64 `gorm:"-" json:"-"` + LastDBRequestUsage int64 `gorm:"-" json:"-"` } // TableName sets the table name for each model @@ -75,3 +79,10 @@ func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error { return nil } + +// AfterFind hook for RateLimit to set the LastDBTokenUsage and LastDBRequestUsage virtual fields +func (rl *TableRateLimit) AfterFind(tx *gorm.DB) error { + rl.LastDBTokenUsage = rl.TokenCurrentUsage + rl.LastDBRequestUsage = rl.RequestCurrentUsage + return nil +} \ No newline at end of file diff --git a/framework/configstore/utils.go b/framework/configstore/utils.go index 78f0e133e..24f1f636c 100644 --- a/framework/configstore/utils.go +++ b/framework/configstore/utils.go @@ -183,32 +183,59 @@ func substituteMCPEnvVars(config *schemas.MCPConfig, envKeys map[string][]EnvKey } // substituteMCPClientEnvVars replaces resolved environment variable values with their original env.VAR_NAME references for a single MCP client config -func substituteMCPClientEnvVars(clientConfig *schemas.MCPClientConfig, envKeys map[string][]EnvKeyInfo) { +// If existingHeaders is provided, it will restore redacted plain header values from the existing headers before substitution +func substituteMCPClientEnvVars(clientConfig *schemas.MCPClientConfig, envKeys map[string][]EnvKeyInfo, existingHeaders map[string]string) { + // First, restore redacted plain header values from existing headers if provided + // This handles the case where UI sends redacted headers that aren't env vars + if existingHeaders != nil && clientConfig.Headers != nil { + for header, value := range clientConfig.Headers { + // Check if the value is redacted (contains **** pattern) and not an env var + if strings.Contains(value, "****") && !strings.HasPrefix(value, "env.") { + // If header exists in existing headers and wasn't an env var, restore it + if oldHeaderValue, exists := existingHeaders[header]; exists { + if !strings.HasPrefix(oldHeaderValue, "env.") { + clientConfig.Headers[header] = oldHeaderValue + } + } + } + } + } + // Find the environment variable for this client's connection string and headers for envVar, keyInfos := range envKeys { for _, keyInfo := range keyInfos { // For MCP connection strings if keyInfo.KeyType == "connection_string" { - // Extract client name from config path like "mcp.client_configs.clientName.connection_string" + // Extract client ID from config path like "mcp.client_configs.clientID.connection_string" pathParts := strings.Split(keyInfo.ConfigPath, ".") if len(pathParts) >= 3 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" { - clientName := pathParts[2] - // If this environment variable is for the current client - if clientName == clientConfig.Name && clientConfig.ConnectionString != nil { + clientID := pathParts[2] + // If this environment variable is for the current client (match by ID) + if clientID == clientConfig.ID && clientConfig.ConnectionString != nil { clientConfig.ConnectionString = &[]string{fmt.Sprintf("env.%s", envVar)}[0] } } } // For MCP headers if keyInfo.KeyType == "mcp_header" { - // Extract client name and header name from config path like "mcp.client_configs.clientName.headers.headerName" + // Extract client ID and header name from config path like "mcp.client_configs.clientID.headers.headerName" pathParts := strings.Split(keyInfo.ConfigPath, ".") if len(pathParts) >= 5 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" && pathParts[3] == "headers" { - clientName := pathParts[2] + clientID := pathParts[2] headerName := pathParts[4] - // If this environment variable is for the current client - if clientName == clientConfig.Name && clientConfig.Headers != nil { - clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + // If this environment variable is for the current client (match by ID) + if clientID == clientConfig.ID && clientConfig.Headers != nil { + if headerValue, exists := clientConfig.Headers[headerName]; exists { + // If it's already in env.VAR format, update to use the correct env var + if strings.HasPrefix(headerValue, "env.") { + clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + } else if strings.Contains(headerValue, "****") { + // If it's redacted (contains ****), restore to env.VAR format + // This handles the case where UI sends redacted headers back for env vars + clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + } + // If it's a plain value (not env. and not redacted), leave it as-is + } } } } diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index ad3997e1f..64bba62ff 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -29,6 +29,8 @@ type ModelCatalog struct { pricingSyncInterval time.Duration pricingMu sync.RWMutex + shouldSyncPricingFunc ShouldSyncPricingFunc + // In-memory cache for fast access - direct map for O(1) lookups pricingData map[string]configstoreTables.TableModelPricing mu sync.RWMutex @@ -76,8 +78,14 @@ type PricingEntry struct { OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` } +// ShouldSyncPricingFunc is a function that determines if pricing data should be synced +// It returns a boolean indicating if syncing is needed +// It is completely optional and can be nil if not needed +// syncPricing function will be called if this function returns true +type ShouldSyncPricingFunc func(ctx context.Context) bool + // Init initializes the pricing manager -func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) { +func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, shouldSyncPricingFunc ShouldSyncPricingFunc, logger schemas.Logger) (*ModelCatalog, error) { // Initialize pricing URL and sync interval pricingURL := DefaultPricingURL if config.PricingURL != nil { @@ -89,13 +97,14 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto } mc := &ModelCatalog{ - pricingURL: pricingURL, - pricingSyncInterval: pricingSyncInterval, - configStore: configStore, - logger: logger, - pricingData: make(map[string]configstoreTables.TableModelPricing), - modelPool: make(map[schemas.ModelProvider][]string), - done: make(chan struct{}), + pricingURL: pricingURL, + pricingSyncInterval: pricingSyncInterval, + configStore: configStore, + logger: logger, + pricingData: make(map[string]configstoreTables.TableModelPricing), + modelPool: make(map[schemas.ModelProvider][]string), + done: make(chan struct{}), + shouldSyncPricingFunc: shouldSyncPricingFunc, } logger.Info("initializing pricing manager...") diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 7f81cbae6..ac94d7e73 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -58,6 +58,13 @@ func (mc *ModelCatalog) shouldSyncPricing(ctx context.Context) (bool, string) { func (mc *ModelCatalog) syncPricing(ctx context.Context) error { mc.logger.Debug("starting pricing data synchronization for governance") + if mc.shouldSyncPricingFunc != nil { + if !mc.shouldSyncPricingFunc(ctx) { + mc.logger.Debug("pricing sync cancelled by custom function") + return nil + } + } + // Load pricing data from URL pricingData, err := mc.loadPricingFromURL(ctx) if err != nil { diff --git a/plugins/governance/advanced_scenarios_test.go b/plugins/governance/advanced_scenarios_test.go new file mode 100644 index 000000000..5d56f7eb2 --- /dev/null +++ b/plugins/governance/advanced_scenarios_test.go @@ -0,0 +1,1675 @@ +package governance + +import ( + "testing" + "time" +) + +// ============================================================================ +// SCENARIO 1: VK Switching Teams After Budget Exhaustion +// ============================================================================ + +// TestVKSwitchTeamAfterBudgetExhaustion verifies that after exhausting one team's budget, +// switching the VK to another team allows requests to pass +func TestVKSwitchTeamAfterBudgetExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create Team 1 with small budget + team1Name := "test-team1-switch-" + generateRandomID() + team1Budget := 0.01 // $0.01 + createTeam1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team1Name, + Budget: &BudgetRequest{ + MaxLimit: team1Budget, + ResetDuration: "1h", + }, + }, + }) + + if createTeam1Resp.StatusCode != 200 { + t.Fatalf("Failed to create team1: status %d", createTeam1Resp.StatusCode) + } + + team1ID := ExtractIDFromResponse(t, createTeam1Resp, "id") + testData.AddTeam(team1ID) + + // Create Team 2 with higher budget + team2Name := "test-team2-switch-" + generateRandomID() + team2Budget := 10.0 // $10 + createTeam2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team2Name, + Budget: &BudgetRequest{ + MaxLimit: team2Budget, + ResetDuration: "1h", + }, + }, + }) + + if createTeam2Resp.StatusCode != 200 { + t.Fatalf("Failed to create team2: status %d", createTeam2Resp.StatusCode) + } + + team2ID := ExtractIDFromResponse(t, createTeam2Resp, "id") + testData.AddTeam(team2ID) + + t.Logf("Created Team1 (budget: $%.2f) and Team2 (budget: $%.2f)", team1Budget, team2Budget) + + // Create VK assigned to Team 1 + vkName := "test-vk-team-switch-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &team1ID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK assigned to Team1") + + // Exhaust Team1's budget + consumedBudget := 0.0 + requestNum := 1 + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Team1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + + if consumedBudget >= team1Budget { + // Make one more request to trigger rejection + continue + } + } + + if consumedBudget < team1Budget { + t.Fatalf("Could not exhaust Team1 budget") + } + + // Now switch VK to Team2 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + TeamID: &team2ID, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to switch VK to Team2: status %d", updateResp.StatusCode) + } + + t.Logf("Switched VK from Team1 to Team2") + + // Wait for in-memory update + time.Sleep(500 * time.Millisecond) + + // Request should now succeed with Team2's budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after switching to Team2"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after switching to Team2 with available budget, got status %d", resp.StatusCode) + } + + t.Logf("VK switch team after budget exhaustion verified ✓") +} + +// ============================================================================ +// SCENARIO 2: VK Switching Customers After Budget Exhaustion +// ============================================================================ + +// TestVKSwitchCustomerAfterBudgetExhaustion verifies that after exhausting one customer's budget, +// switching the VK to another customer allows requests to pass +func TestVKSwitchCustomerAfterBudgetExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create Customer 1 with small budget + customer1Name := "test-customer1-switch-" + generateRandomID() + customer1Budget := 0.01 // $0.01 + createCustomer1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer1Name, + Budget: &BudgetRequest{ + MaxLimit: customer1Budget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomer1Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer1: status %d", createCustomer1Resp.StatusCode) + } + + customer1ID := ExtractIDFromResponse(t, createCustomer1Resp, "id") + testData.AddCustomer(customer1ID) + + // Create Customer 2 with higher budget + customer2Name := "test-customer2-switch-" + generateRandomID() + customer2Budget := 10.0 // $10 + createCustomer2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer2Name, + Budget: &BudgetRequest{ + MaxLimit: customer2Budget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomer2Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer2: status %d", createCustomer2Resp.StatusCode) + } + + customer2ID := ExtractIDFromResponse(t, createCustomer2Resp, "id") + testData.AddCustomer(customer2ID) + + t.Logf("Created Customer1 (budget: $%.2f) and Customer2 (budget: $%.2f)", customer1Budget, customer2Budget) + + // Create VK assigned directly to Customer 1 + vkName := "test-vk-customer-switch-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + CustomerID: &customer1ID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK assigned to Customer1") + + // Exhaust Customer1's budget + consumedBudget := 0.0 + requestNum := 1 + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Customer1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + + if consumedBudget >= customer1Budget { + continue + } + } + + if consumedBudget < customer1Budget { + t.Fatalf("Could not exhaust Customer1 budget") + } + + // Now switch VK to Customer2 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + CustomerID: &customer2ID, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to switch VK to Customer2: status %d", updateResp.StatusCode) + } + + t.Logf("Switched VK from Customer1 to Customer2") + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed with Customer2's budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after switching to Customer2"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after switching to Customer2 with available budget, got status %d", resp.StatusCode) + } + + t.Logf("VK switch customer after budget exhaustion verified ✓") +} + +// ============================================================================ +// SCENARIO 3: Hierarchical Chain VK->Team->Customer Budget Switching +// ============================================================================ + +// TestHierarchicalChainBudgetSwitch verifies switching the entire hierarchy +func TestHierarchicalChainBudgetSwitch(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create Customer 1 with small budget + customer1Name := "test-customer1-hierarchy-" + generateRandomID() + createCustomer1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer1Name, + Budget: &BudgetRequest{ + MaxLimit: 0.01, // $0.01 - most restrictive + ResetDuration: "1h", + }, + }, + }) + + if createCustomer1Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer1: status %d", createCustomer1Resp.StatusCode) + } + + customer1ID := ExtractIDFromResponse(t, createCustomer1Resp, "id") + testData.AddCustomer(customer1ID) + + // Create Team 1 under Customer 1 + team1Name := "test-team1-hierarchy-" + generateRandomID() + createTeam1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team1Name, + CustomerID: &customer1ID, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // High budget - customer is limiting + ResetDuration: "1h", + }, + }, + }) + + if createTeam1Resp.StatusCode != 200 { + t.Fatalf("Failed to create team1: status %d", createTeam1Resp.StatusCode) + } + + team1ID := ExtractIDFromResponse(t, createTeam1Resp, "id") + testData.AddTeam(team1ID) + + // Create Customer 2 with higher budget + customer2Name := "test-customer2-hierarchy-" + generateRandomID() + createCustomer2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer2Name, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // High budget + ResetDuration: "1h", + }, + }, + }) + + if createCustomer2Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer2: status %d", createCustomer2Resp.StatusCode) + } + + customer2ID := ExtractIDFromResponse(t, createCustomer2Resp, "id") + testData.AddCustomer(customer2ID) + + // Create Team 2 under Customer 2 + team2Name := "test-team2-hierarchy-" + generateRandomID() + createTeam2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team2Name, + CustomerID: &customer2ID, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // High budget + ResetDuration: "1h", + }, + }, + }) + + if createTeam2Resp.StatusCode != 200 { + t.Fatalf("Failed to create team2: status %d", createTeam2Resp.StatusCode) + } + + team2ID := ExtractIDFromResponse(t, createTeam2Resp, "id") + testData.AddTeam(team2ID) + + t.Logf("Created hierarchy: Customer1(low budget)->Team1 and Customer2(high budget)->Team2") + + // Create VK assigned to Team 1 + vkName := "test-vk-hierarchy-switch-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &team1ID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + // Exhaust Customer1's budget (which is limiting Team1) + consumedBudget := 0.0 + requestNum := 1 + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Customer1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + // Switch VK to Team2 (under Customer2) + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + TeamID: &team2ID, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to switch VK to Team2: status %d", updateResp.StatusCode) + } + + t.Logf("Switched VK from Team1(Customer1) to Team2(Customer2)") + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after switching hierarchy"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after switching hierarchy, got status %d", resp.StatusCode) + } + + t.Logf("Hierarchical chain budget switch verified ✓") +} + +// ============================================================================ +// SCENARIO 4: VK Budget Update After Exhaustion +// ============================================================================ + +// TestVKBudgetUpdateAfterExhaustion verifies that updating VK budget after exhaustion allows requests +func TestVKBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with small budget + vkName := "test-vk-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget: $%.2f", initialBudget) + + // Exhaust VK budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("VK budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update VK budget to a higher value + newBudget := 10.0 + resetDuration := "1h" + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated VK budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after budget update, got status %d", resp.StatusCode) + } + + t.Logf("VK budget update after exhaustion verified ✓") +} + +// ============================================================================ +// SCENARIO 5: Team Budget Update After Exhaustion +// ============================================================================ + +// TestTeamBudgetUpdateAfterExhaustion verifies that updating team budget after exhaustion allows requests +func TestTeamBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team with small budget + teamName := "test-team-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team + vkName := "test-vk-team-budget-update-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created team with budget: $%.2f", initialBudget) + + // Exhaust team budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("Team budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update team budget + newBudget := 10.0 + resetDuration := "1h" + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/teams/" + teamID, + Body: UpdateTeamRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update team budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated team budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after team budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after team budget update, got status %d", resp.StatusCode) + } + + t.Logf("Team budget update after exhaustion verified ✓") +} + +// ============================================================================ +// SCENARIO 6: Customer Budget Update After Exhaustion +// ============================================================================ + +// TestCustomerBudgetUpdateAfterExhaustion verifies that updating customer budget after exhaustion allows requests +func TestCustomerBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with small budget + customerName := "test-customer-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create team under customer + teamName := "test-team-customer-update-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + CustomerID: &customerID, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team + vkName := "test-vk-customer-budget-update-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created customer with budget: $%.2f", initialBudget) + + // Exhaust customer budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("Customer budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update customer budget + newBudget := 10.0 + resetDuration := "1h" + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/customers/" + customerID, + Body: UpdateCustomerRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update customer budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated customer budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after customer budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after customer budget update, got status %d", resp.StatusCode) + } + + t.Logf("Customer budget update after exhaustion verified ✓") +} + +// ============================================================================ +// SCENARIO 7: Provider Config Budget Update After Exhaustion +// ============================================================================ + +// TestProviderConfigBudgetUpdateAfterExhaustion verifies that updating provider config budget after exhaustion allows requests +func TestProviderConfigBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with provider config budget + vkName := "test-vk-provider-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider config budget: $%.2f", initialBudget) + + // Get provider config ID + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + providerConfigs := vkData["provider_configs"].([]interface{}) + providerConfig := providerConfigs[0].(map[string]interface{}) + providerConfigID := uint(providerConfig["id"].(float64)) + + // Exhaust provider config budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("Provider config budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update provider config budget + newBudget := 10.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + ProviderConfigs: []ProviderConfigRequest{ + { + ID: &providerConfigID, + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: newBudget, + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update provider config budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated provider config budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after provider config budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after provider config budget update, got status %d", resp.StatusCode) + } + + t.Logf("Provider config budget update after exhaustion verified ✓") +} + +// ============================================================================ +// SCENARIO 8: VK Deletion Cascade +// ============================================================================ + +// TestVKDeletionCascadeComplete verifies deleting VK removes provider configs, budgets, and rate limits from memory +func TestVKDeletionCascadeComplete(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with budget, rate limit, and provider configs + vkName := "test-vk-deletion-cascade-" + generateRandomID() + tokenLimit := int64(10000) + tokenResetDuration := "1h" + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 10.0, + ResetDuration: "1h", + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 5.0, + ResetDuration: "1h", + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + // Don't add to testData since we'll delete manually + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget, rate limit, and provider config") + + // Get initial state from in-memory store + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + + // Verify VK exists + _, vkExists := virtualKeysMap1[vkValue] + if !vkExists { + t.Fatalf("VK not found in in-memory store") + } + + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + vkBudgetID := vkData1["budget_id"].(string) + vkRateLimitID := vkData1["rate_limit_id"].(string) + providerConfigs := vkData1["provider_configs"].([]interface{}) + pc := providerConfigs[0].(map[string]interface{}) + pcBudgetID := pc["budget_id"].(string) + pcRateLimitID := pc["rate_limit_id"].(string) + + // Verify all resources exist in memory + _, vkBudgetExists := budgetsMap1[vkBudgetID] + _, vkRateLimitExists := rateLimitsMap1[vkRateLimitID] + _, pcBudgetExists := budgetsMap1[pcBudgetID] + _, pcRateLimitExists := rateLimitsMap1[pcRateLimitID] + + if !vkBudgetExists || !vkRateLimitExists || !pcBudgetExists || !pcRateLimitExists { + t.Fatalf("Not all resources found in memory before deletion") + } + + t.Logf("All resources exist in memory before deletion ✓") + + // Delete VK + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/virtual-keys/" + vkID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete VK: status %d", deleteResp.StatusCode) + } + + t.Logf("VK deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify VK and all related resources are removed from memory + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + + // VK should be gone + _, vkStillExists := virtualKeysMap2[vkValue] + if vkStillExists { + t.Fatalf("VK still exists in memory after deletion") + } + + // Budgets should be gone + _, vkBudgetStillExists := budgetsMap2[vkBudgetID] + _, pcBudgetStillExists := budgetsMap2[pcBudgetID] + if vkBudgetStillExists || pcBudgetStillExists { + t.Fatalf("Budgets should be cascade-deleted: VK budget exists=%v, PC budget exists=%v", + vkBudgetStillExists, pcBudgetStillExists) + } + + // Rate limits should be gone + _, vkRateLimitStillExists := rateLimitsMap2[vkRateLimitID] + _, pcRateLimitStillExists := rateLimitsMap2[pcRateLimitID] + if vkRateLimitStillExists || pcRateLimitStillExists { + t.Logf("Note: Rate limits may still exist in memory (orphaned) - this is acceptable") + } + + t.Logf("VK removed from memory after deletion ✓") + t.Logf("VK deletion cascade verified ✓") +} + +// ============================================================================ +// SCENARIO 9: Team/Customer Deletion Should Delete Budget +// ============================================================================ + +// TestTeamDeletionDeletesBudget verifies that deleting a team also deletes its budget from memory +func TestTeamDeletionDeletesBudget(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team with budget + teamName := "test-team-delete-budget-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: 100.0, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + // Don't add to testData since we'll delete manually + + t.Logf("Created team with budget") + + // Get budget ID from in-memory store + getTeamsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + teamsMap1 := getTeamsResp1.Body["teams"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + teamData1 := teamsMap1[teamID].(map[string]interface{}) + budgetID := teamData1["budget_id"].(string) + + _, budgetExists := budgetsMap1[budgetID] + if !budgetExists { + t.Fatalf("Budget not found in memory before deletion") + } + + t.Logf("Team and budget exist in memory ✓") + + // Delete team + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/teams/" + teamID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete team: status %d", deleteResp.StatusCode) + } + + t.Logf("Team deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify team and budget are removed from memory + getTeamsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + teamsMap2 := getTeamsResp2.Body["teams"].(map[string]interface{}) + + _, teamStillExists := teamsMap2[teamID] + if teamStillExists { + t.Fatalf("Team still exists in memory after deletion") + } + + t.Logf("Team removed from memory ✓") + + // Verify budget is also removed from memory + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + if getBudgetsResp2.StatusCode != 200 { + t.Fatalf("Failed to get budgets from memory: status %d", getBudgetsResp2.StatusCode) + } + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + _, budgetStillExists := budgetsMap2[budgetID] + if budgetStillExists { + t.Fatalf("Budget %s still exists in memory after team deletion", budgetID) + } + + t.Logf("Budget removed from memory ✓") + t.Logf("Team deletion with budget verified ✓") +} + +// TestCustomerDeletionDeletesBudget verifies that deleting a customer also deletes its budget from memory +func TestCustomerDeletionDeletesBudget(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with budget + customerName := "test-customer-delete-budget-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: 100.0, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + // Don't add to testData since we'll delete manually + + t.Logf("Created customer with budget") + + // Get budget ID from in-memory store + getCustomersResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + customersMap1 := getCustomersResp1.Body["customers"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + customerData1 := customersMap1[customerID].(map[string]interface{}) + budgetID := customerData1["budget_id"].(string) + + _, budgetExists := budgetsMap1[budgetID] + if !budgetExists { + t.Fatalf("Budget not found in memory before deletion") + } + + t.Logf("Customer and budget exist in memory ✓") + + // Delete customer + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/customers/" + customerID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete customer: status %d", deleteResp.StatusCode) + } + + t.Logf("Customer deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify customer is removed from memory + getCustomersResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + customersMap2 := getCustomersResp2.Body["customers"].(map[string]interface{}) + + _, customerStillExists := customersMap2[customerID] + if customerStillExists { + t.Fatalf("Customer still exists in memory after deletion") + } + + t.Logf("Customer removed from memory ✓") + + // Verify budget is also removed from memory + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + if getBudgetsResp2.StatusCode != 200 { + t.Fatalf("Failed to get budgets from memory: status %d", getBudgetsResp2.StatusCode) + } + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + _, budgetStillExists := budgetsMap2[budgetID] + if budgetStillExists { + t.Fatalf("Budget still exists in memory after customer deletion") + } + + t.Logf("Budget removed from memory ✓") + t.Logf("Customer deletion with budget verified ✓") +} + +// ============================================================================ +// SCENARIO 10: Team/Customer Deletion Sets VK entity_id = nil +// ============================================================================ + +// TestTeamDeletionSetsVKTeamIDToNil verifies that deleting a team sets team_id=nil on associated VKs +func TestTeamDeletionSetsVKTeamIDToNil(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team + teamName := "test-team-vk-nil-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + // Don't add to testData since we'll delete manually + + // Create VK assigned to team + vkName := "test-vk-team-nil-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created team and VK assigned to it") + + // Verify VK has team_id set + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + + teamIDFromVK1, hasTeamID := vkData1["team_id"].(string) + if !hasTeamID || teamIDFromVK1 != teamID { + t.Fatalf("VK team_id not set correctly before team deletion") + } + + t.Logf("VK has team_id=%s ✓", teamID) + + // Delete team + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/teams/" + teamID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete team: status %d", deleteResp.StatusCode) + } + + t.Logf("Team deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify VK still exists but team_id is nil + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + vkData2, vkStillExists := virtualKeysMap2[vkValue].(map[string]interface{}) + if !vkStillExists { + t.Fatalf("VK should still exist after team deletion") + } + + teamIDFromVK2, hasTeamID2 := vkData2["team_id"].(string) + if hasTeamID2 && teamIDFromVK2 != "" { + t.Fatalf("VK team_id should be nil after team deletion, got: %s", teamIDFromVK2) + } + + t.Logf("VK team_id is now nil ✓") + t.Logf("Team deletion sets VK team_id to nil verified ✓") +} + +// TestCustomerDeletionSetsVKCustomerIDToNil verifies that deleting a customer sets customer_id=nil on associated VKs +func TestCustomerDeletionSetsVKCustomerIDToNil(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer + customerName := "test-customer-vk-nil-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + // Don't add to testData since we'll delete manually + + // Create VK assigned directly to customer + vkName := "test-vk-customer-nil-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + CustomerID: &customerID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created customer and VK assigned to it") + + // Verify VK has customer_id set + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + + customerIDFromVK1, hasCustomerID := vkData1["customer_id"].(string) + if !hasCustomerID || customerIDFromVK1 != customerID { + t.Fatalf("VK customer_id not set correctly before customer deletion") + } + + t.Logf("VK has customer_id=%s ✓", customerID) + + // Delete customer + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/customers/" + customerID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete customer: status %d", deleteResp.StatusCode) + } + + t.Logf("Customer deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify VK still exists but customer_id is nil + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + vkData2, vkStillExists := virtualKeysMap2[vkValue].(map[string]interface{}) + if !vkStillExists { + t.Fatalf("VK should still exist after customer deletion") + } + + customerIDFromVK2, hasCustomerID2 := vkData2["customer_id"].(string) + if hasCustomerID2 && customerIDFromVK2 != "" { + t.Fatalf("VK customer_id should be nil after customer deletion, got: %s", customerIDFromVK2) + } + + t.Logf("VK customer_id is now nil ✓") + t.Logf("Customer deletion sets VK customer_id to nil verified ✓") +} diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index e69de29bb..89d262917 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -0,0 +1,3 @@ +- refactor: extracted governance store into an interface for extensibility +- refactor: extended the way governance store handles rate limits +- chore: added e2e tests for governance plugin diff --git a/plugins/governance/config_update_sync_test.go b/plugins/governance/config_update_sync_test.go new file mode 100644 index 000000000..a252c7d6e --- /dev/null +++ b/plugins/governance/config_update_sync_test.go @@ -0,0 +1,1123 @@ +package governance + +import ( + "testing" + "time" +) + +// ============================================================================ +// VK-LEVEL RATE LIMIT UPDATE SYNC +// ============================================================================ + +// TestVKRateLimitUpdateSyncToMemory tests that VK rate limit updates sync to in-memory store +// and that usage resets to 0 when new max limit < current usage +func TestVKRateLimitUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with initial rate limit + vkName := "test-vk-rate-update-" + generateRandomID() + initialTokenLimit := int64(10000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &initialTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with initial token limit: %d", initialTokenLimit) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID1, _ := vkData1["rate_limit_id"].(string) + + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + rateLimit1 := rateLimitsMap1[rateLimitID1].(map[string]interface{}) + + initialTokenMaxLimit, _ := rateLimit1["token_max_limit"].(float64) + initialTokenUsage, _ := rateLimit1["token_current_usage"].(float64) + + if int64(initialTokenMaxLimit) != initialTokenLimit { + t.Fatalf("Initial token max limit not correct: expected %d, got %d", initialTokenLimit, int64(initialTokenMaxLimit)) + } + + t.Logf("Initial state in memory: TokenMaxLimit=%d, TokenCurrentUsage=%d", int64(initialTokenMaxLimit), int64(initialTokenUsage)) + + // Make a request to consume some tokens + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume tokens.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume tokens") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getVKResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData2 := getVKResp2.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID2, _ := vkData2["rate_limit_id"].(string) + + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2 := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + + tokenUsageBeforeUpdate, _ := rateLimit2["token_current_usage"].(float64) + t.Logf("Token usage after request: %d", int64(tokenUsageBeforeUpdate)) + + if tokenUsageBeforeUpdate <= 0 { + t.Skip("No tokens consumed - cannot test usage reset") + } + + // NOW UPDATE: set new limit LOWER than current usage to trigger reset + // Usage reset only happens when new max limit <= current usage + newLowerLimit := int64(tokenUsageBeforeUpdate / 2) // Set to half of current usage to ensure it's lower + if newLowerLimit <= 0 { + newLowerLimit = int64(tokenUsageBeforeUpdate / 10) // Fallback to 10% if too small + } + if newLowerLimit <= 0 { + newLowerLimit = 1 // Minimum of 1 + } + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newLowerLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK rate limit: status %d", updateResp.StatusCode) + } + + t.Logf("Updated token limit from %d to %d (new limit %d <= current usage %d)", initialTokenLimit, newLowerLimit, newLowerLimit, int64(tokenUsageBeforeUpdate)) + + // Wait for update to sync + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getVKResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData3 := getVKResp3.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID3, _ := vkData3["rate_limit_id"].(string) + + getRateLimitsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap3 := getRateLimitsResp3.Body["rate_limits"].(map[string]interface{}) + rateLimit3 := rateLimitsMap3[rateLimitID3].(map[string]interface{}) + + newTokenMaxLimit, _ := rateLimit3["token_max_limit"].(float64) + tokenUsageAfterUpdate, _ := rateLimit3["token_current_usage"].(float64) + + // Verify new max limit is reflected + if int64(newTokenMaxLimit) != newLowerLimit { + t.Fatalf("Token max limit not updated in memory: expected %d, got %d", newLowerLimit, int64(newTokenMaxLimit)) + } + + t.Logf("✓ Token max limit updated in memory: %d", int64(newTokenMaxLimit)) + + // Verify usage reset to 0 (since new max limit <= current usage) + if tokenUsageAfterUpdate > 0.001 { + t.Fatalf("Token usage should reset to 0 when new limit (%d) <= current usage (%d), but got %d", newLowerLimit, int64(tokenUsageBeforeUpdate), int64(tokenUsageAfterUpdate)) + } + + t.Logf("✓ Token usage correctly reset to 0 (new limit: %d <= old usage: %d)", int64(newTokenMaxLimit), int64(tokenUsageBeforeUpdate)) + + // Test UPDATE with higher limit (usage should NOT reset) + newerHigherLimit := int64(50000) + updateResp2 := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newerHigherLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if updateResp2.StatusCode != 200 { + t.Fatalf("Failed to update VK rate limit second time: status %d", updateResp2.StatusCode) + } + + time.Sleep(500 * time.Millisecond) + + getVKResp4 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData4 := getVKResp4.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID4, _ := vkData4["rate_limit_id"].(string) + + getRateLimitsResp4 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap4 := getRateLimitsResp4.Body["rate_limits"].(map[string]interface{}) + rateLimit4 := rateLimitsMap4[rateLimitID4].(map[string]interface{}) + + newerTokenMaxLimit, _ := rateLimit4["token_max_limit"].(float64) + tokenUsageAfterSecondUpdate, _ := rateLimit4["token_current_usage"].(float64) + + // Verify new higher limit is reflected + if int64(newerTokenMaxLimit) != newerHigherLimit { + t.Fatalf("Token max limit not updated to higher value: expected %d, got %d", newerHigherLimit, int64(newerTokenMaxLimit)) + } + + t.Logf("✓ Token max limit updated to higher value: %d", int64(newerTokenMaxLimit)) + + // Since usage is 0 and new limit is higher, usage stays 0 + if tokenUsageAfterSecondUpdate != 0 { + t.Logf("Note: Token usage is %d (expected 0 since it was reset)", int64(tokenUsageAfterSecondUpdate)) + } + + t.Logf("VK rate limit update sync to memory verified ✓") +} + +// TestVKBudgetUpdateSyncToMemory tests that VK budget updates sync to in-memory store +// and that usage resets to 0 when new max budget < current usage +func TestVKBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with initial budget + vkName := "test-vk-budget-update-" + generateRandomID() + initialBudget := 10.0 // $10 + resetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with initial budget: $%.2f", initialBudget) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + budgetID, _ := vkData1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial budget max limit not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial state in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make a request to consume some budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume budget") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budget2 := budgetsMap2[budgetID].(map[string]interface{}) + + usageBeforeUpdate, _ := budget2["current_usage"].(float64) + t.Logf("Budget usage after request: $%.6f", usageBeforeUpdate) + + if usageBeforeUpdate <= 0 { + t.Skip("No budget consumed - cannot test usage reset") + } + + // UPDATE: set new limit LOWER than current usage to trigger reset + // Usage reset only happens when new max limit <= current usage + newLowerBudget := usageBeforeUpdate * 0.5 // Set to half of current usage to ensure it's lower + if newLowerBudget <= 0 { + newLowerBudget = usageBeforeUpdate * 0.1 // Fallback to 10% if too small + } + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newLowerBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated budget from $%.2f to $%.6f (new limit %.6f < current usage %.6f)", initialBudget, newLowerBudget, newLowerBudget, usageBeforeUpdate) + + // Wait for update to sync + time.Sleep(1500 * time.Millisecond) + + // Verify update in in-memory store + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budget3 := budgetsMap3[budgetID].(map[string]interface{}) + + newMaxLimit, _ := budget3["max_limit"].(float64) + usageAfterUpdate, _ := budget3["current_usage"].(float64) + + // Verify new max limit is reflected + if newMaxLimit != newLowerBudget { + t.Fatalf("Budget max limit not updated in memory: expected %.6f, got %.6f", newLowerBudget, newMaxLimit) + } + + t.Logf("✓ Budget max limit updated in memory: $%.6f", newMaxLimit) + + // Verify usage reset to 0 (since new max limit <= current usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Budget usage should reset to 0 when new limit (%.6f) <= current usage (%.6f), but got $%.6f", newMaxLimit, usageBeforeUpdate, usageAfterUpdate) + } + + t.Logf("✓ Budget usage correctly reset to 0 (new limit: $%.6f <= old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("VK budget update sync to memory verified ✓") +} + +// ============================================================================ +// PROVIDER CONFIG RATE LIMIT UPDATE SYNC +// ============================================================================ + +// TestProviderRateLimitUpdateSyncToMemory tests that provider config rate limit updates sync to memory +func TestProviderRateLimitUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with provider config and initial rate limit + vkName := "test-vk-provider-rate-update-" + generateRandomID() + initialTokenLimit := int64(5000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &initialTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider config, initial token limit: %d", initialTokenLimit) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs1 := vkData1["provider_configs"].([]interface{}) + providerConfig1 := providerConfigs1[0].(map[string]interface{}) + providerConfigID := uint(providerConfig1["id"].(float64)) + rateLimitID1, _ := providerConfig1["rate_limit_id"].(string) + + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + rateLimit1 := rateLimitsMap1[rateLimitID1].(map[string]interface{}) + + initialTokenMaxLimit, _ := rateLimit1["token_max_limit"].(float64) + initialTokenUsage, _ := rateLimit1["token_current_usage"].(float64) + + if int64(initialTokenMaxLimit) != initialTokenLimit { + t.Fatalf("Initial token max limit not correct: expected %d, got %d", initialTokenLimit, int64(initialTokenMaxLimit)) + } + + t.Logf("Initial provider rate limit in memory: TokenMaxLimit=%d, TokenCurrentUsage=%d", int64(initialTokenMaxLimit), int64(initialTokenUsage)) + + // Make a request to consume some tokens + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume provider tokens.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume provider tokens") + } + + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getVKResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData2 := getVKResp2.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs2 := vkData2["provider_configs"].([]interface{}) + providerConfig2 := providerConfigs2[0].(map[string]interface{}) + rateLimitID2, _ := providerConfig2["rate_limit_id"].(string) + + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2 := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + + tokenUsageBeforeUpdate, _ := rateLimit2["token_current_usage"].(float64) + t.Logf("Provider token usage after request: %d", int64(tokenUsageBeforeUpdate)) + + if tokenUsageBeforeUpdate <= 0 { + t.Skip("No provider tokens consumed - cannot test usage reset") + } + + // UPDATE: set new limit LOWER than current usage + newLowerLimit := int64(50) // Much lower + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + ProviderConfigs: []ProviderConfigRequest{ + { + ID: &providerConfigID, + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newLowerLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update provider rate limit: status %d", updateResp.StatusCode) + } + + t.Logf("Updated provider token limit from %d to %d", initialTokenLimit, newLowerLimit) + + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getVKResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData3 := getVKResp3.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs3 := vkData3["provider_configs"].([]interface{}) + providerConfig3 := providerConfigs3[0].(map[string]interface{}) + rateLimitID3, _ := providerConfig3["rate_limit_id"].(string) + + getRateLimitsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap3 := getRateLimitsResp3.Body["rate_limits"].(map[string]interface{}) + rateLimit3 := rateLimitsMap3[rateLimitID3].(map[string]interface{}) + + newTokenMaxLimit, _ := rateLimit3["token_max_limit"].(float64) + tokenUsageAfterUpdate, _ := rateLimit3["token_current_usage"].(float64) + + // Verify new limit is reflected + if int64(newTokenMaxLimit) != newLowerLimit { + t.Fatalf("Provider token max limit not updated: expected %d, got %d", newLowerLimit, int64(newTokenMaxLimit)) + } + + t.Logf("✓ Provider token max limit updated in memory: %d", int64(newTokenMaxLimit)) + + // Verify usage reset to 0 (since new max < old usage) + if tokenUsageAfterUpdate > 0.001 { + t.Fatalf("Provider token usage should reset to 0 when new limit < current usage, but got %d", int64(tokenUsageAfterUpdate)) + } + + t.Logf("✓ Provider token usage reset to 0 (new limit: %d < old usage: %d)", int64(newTokenMaxLimit), int64(tokenUsageBeforeUpdate)) + + t.Logf("Provider rate limit update sync to memory verified ✓") +} + +// ============================================================================ +// TEAM BUDGET UPDATE SYNC +// ============================================================================ + +// TestTeamBudgetUpdateSyncToMemory tests that team budget updates sync to in-memory store +func TestTeamBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team with initial budget + teamName := "test-team-budget-update-" + generateRandomID() + initialBudget := 5.0 + resetDuration := "1h" + + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team to consume budget + vkName := "test-vk-under-team-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created team with initial budget: $%.2f", initialBudget) + + // Get initial in-memory state + getTeamsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + teamsMap1 := getTeamsResp1.Body["teams"].(map[string]interface{}) + teamData1 := teamsMap1[teamID].(map[string]interface{}) + budgetID, _ := teamData1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial budget not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial team budget in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make request to consume team budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume team budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume team budget") + } + + // Wait for usage to be updated in memory + var usageBeforeUpdate float64 + usageUpdated := WaitForCondition(t, func() bool { + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + if budget, ok := budgetsMap[budgetID].(map[string]interface{}); ok { + if usage, ok := budget["current_usage"].(float64); ok && usage > 0 { + usageBeforeUpdate = usage + return true + } + } + return false + }, 3*time.Second, "team budget usage > 0") + + if !usageUpdated { + t.Skip("Team budget usage did not update in time") + } + + t.Logf("Team budget usage after request: $%.6f", usageBeforeUpdate) + + // UPDATE: set new limit LOWER than current usage + newLowerBudget := 0.001 + resetDurationPtr := resetDuration + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/teams/" + teamID, + Body: UpdateTeamRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newLowerBudget, + ResetDuration: &resetDurationPtr, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update team budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated team budget from $%.2f to $%.2f", initialBudget, newLowerBudget) + + // Wait for update to sync to in-memory store + var newMaxLimit, usageAfterUpdate float64 + updateSynced := WaitForCondition(t, func() bool { + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + if budget, ok := budgetsMap[budgetID].(map[string]interface{}); ok { + if maxLimit, ok := budget["max_limit"].(float64); ok { + newMaxLimit = maxLimit + usageAfterUpdate, _ = budget["current_usage"].(float64) + // Check if the new limit has been applied + return maxLimit == newLowerBudget + } + } + return false + }, 3*time.Second, "team budget max limit updated to new value") + + if !updateSynced { + t.Fatalf("Team budget update did not sync to memory in time") + } + + t.Logf("✓ Team budget max limit updated in memory: $%.2f", newMaxLimit) + + // Verify usage reset to 0 (since new max < old usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Team budget usage should reset to 0 when new limit < current usage, but got $%.6f", usageAfterUpdate) + } + + t.Logf("✓ Team budget usage correctly reset to 0 (new limit: $%.2f < old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("Team budget update sync to memory verified ✓") +} + +// ============================================================================ +// CUSTOMER BUDGET UPDATE SYNC +// ============================================================================ + +// TestCustomerBudgetUpdateSyncToMemory tests that customer budget updates sync to in-memory store +func TestCustomerBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with initial budget + customerName := "test-customer-budget-update-" + generateRandomID() + initialBudget := 20.0 + resetDuration := "1h" + + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create team and VK under customer + teamName := "test-team-under-customer-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + CustomerID: &customerID, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + vkName := "test-vk-under-customer-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created customer with initial budget: $%.2f", initialBudget) + + // Get initial in-memory state + getCustomersResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + customersMap1 := getCustomersResp1.Body["customers"].(map[string]interface{}) + customerData1 := customersMap1[customerID].(map[string]interface{}) + budgetID, _ := customerData1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial customer budget not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial customer budget in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make request to consume customer budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume customer budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume customer budget") + } + + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budget2 := budgetsMap2[budgetID].(map[string]interface{}) + + usageBeforeUpdate, _ := budget2["current_usage"].(float64) + t.Logf("Customer budget usage after request: $%.6f", usageBeforeUpdate) + + if usageBeforeUpdate <= 0 { + t.Skip("No customer budget consumed") + } + + // UPDATE: set new limit LOWER than current usage + newLowerBudget := 0.001 + resetDurationPtr := resetDuration + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/customers/" + customerID, + Body: UpdateCustomerRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newLowerBudget, + ResetDuration: &resetDurationPtr, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update customer budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated customer budget from $%.2f to $%.2f", initialBudget, newLowerBudget) + + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budget3 := budgetsMap3[budgetID].(map[string]interface{}) + + newMaxLimit, _ := budget3["max_limit"].(float64) + usageAfterUpdate, _ := budget3["current_usage"].(float64) + + // Verify new limit is reflected + if newMaxLimit != newLowerBudget { + t.Fatalf("Customer budget max limit not updated: expected %.2f, got %.2f", newLowerBudget, newMaxLimit) + } + + t.Logf("✓ Customer budget max limit updated in memory: $%.2f", newMaxLimit) + + // Verify usage reset to 0 (since new max < old usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Customer budget usage should reset to 0 when new limit < current usage, but got $%.6f", usageAfterUpdate) + } + + t.Logf("✓ Customer budget usage correctly reset to 0 (new limit: $%.2f < old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("Customer budget update sync to memory verified ✓") +} + +// ============================================================================ +// PROVIDER CONFIG BUDGET UPDATE SYNC +// ============================================================================ + +// TestProviderBudgetUpdateSyncToMemory tests that provider config budget updates sync to memory +func TestProviderBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with provider config and initial budget + vkName := "test-vk-provider-budget-update-" + generateRandomID() + initialBudget := 5.0 + resetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider budget: $%.2f", initialBudget) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs1 := vkData1["provider_configs"].([]interface{}) + providerConfig1 := providerConfigs1[0].(map[string]interface{}) + providerConfigID := uint(providerConfig1["id"].(float64)) + budgetID, _ := providerConfig1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial provider budget not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial provider budget in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make request to consume provider budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume provider budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume provider budget") + } + + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budget2 := budgetsMap2[budgetID].(map[string]interface{}) + + usageBeforeUpdate, _ := budget2["current_usage"].(float64) + t.Logf("Provider budget usage after request: $%.6f", usageBeforeUpdate) + + if usageBeforeUpdate <= 0 { + t.Skip("No provider budget consumed") + } + + // UPDATE: set new limit LOWER than current usage + newLowerBudget := 0.001 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + ProviderConfigs: []ProviderConfigRequest{ + { + ID: &providerConfigID, + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: newLowerBudget, + ResetDuration: resetDuration, + }, + }, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update provider budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated provider budget from $%.2f to $%.2f", initialBudget, newLowerBudget) + + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budget3 := budgetsMap3[budgetID].(map[string]interface{}) + + newMaxLimit, _ := budget3["max_limit"].(float64) + usageAfterUpdate, _ := budget3["current_usage"].(float64) + + // Verify new limit is reflected + if newMaxLimit != newLowerBudget { + t.Fatalf("Provider budget max limit not updated: expected %.2f, got %.2f", newLowerBudget, newMaxLimit) + } + + t.Logf("✓ Provider budget max limit updated in memory: $%.2f", newMaxLimit) + + // Verify usage reset to 0 (since new max < old usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Provider budget usage should reset to 0 when new limit < current usage, but got $%.6f", usageAfterUpdate) + } + + t.Logf("✓ Provider budget usage correctly reset to 0 (new limit: $%.2f < old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("Provider budget update sync to memory verified ✓") +} diff --git a/plugins/governance/customer_budget_test.go b/plugins/governance/customer_budget_test.go new file mode 100644 index 000000000..79e04c1df --- /dev/null +++ b/plugins/governance/customer_budget_test.go @@ -0,0 +1,335 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestCustomerBudgetExceededWithMultipleVKs tests that customer level budgets are enforced across multiple VKs +// by making requests until budget is consumed +func TestCustomerBudgetExceededWithMultipleVKs(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with a fixed budget + customerBudget := 0.01 + customerName := "test-customer-budget-exceeded-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: customerBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create 2 VKs under the customer (directly, without team) + var vkValues []string + for i := 1; i <= 2; i++ { + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-" + generateRandomID(), + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High VK budget so customer is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK %d: status %d", i, createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValues = append(vkValues, vk["value"].(string)) + } + + t.Logf("Created customer %s with budget $%.2f and 2 VKs", customerName, customerBudget) + + // Keep making requests alternating between VKs, tracking actual token usage until customer budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + vkIndex := 0 + + for requestNum <= 50 { + // Alternate between VKs to test shared customer budget + vkValue := vkValues[vkIndex%2] + + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "customer") { + t.Logf("Request %d correctly rejected: customer budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, customerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d (VK%d) succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, (vkIndex%2)+1, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, customerBudget) + } + } + } + + requestNum++ + vkIndex++ + + if shouldStop { + break + } + + if consumedBudget >= customerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit customer budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, customerBudget) +} + +// TestCustomerBudgetExceededWithMultipleTeams tests that customer level budgets are enforced across multiple teams +// by making requests until budget is consumed +func TestCustomerBudgetExceededWithMultipleTeams(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with a fixed budget + customerBudget := 0.01 + customerName := "test-customer-multi-team-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: customerBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create 2 teams under the customer + var vkValues []string + for i := 1; i <= 2; i++ { + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: "test-team-" + generateRandomID(), + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High team budget so customer is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team %d: status %d", i, createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create a VK under each team + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-" + generateRandomID(), + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High VK budget so customer is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK %d: status %d", i, createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValues = append(vkValues, vk["value"].(string)) + } + + t.Logf("Created customer %s with budget $%.2f and 2 teams with VKs", customerName, customerBudget) + + // Keep making requests alternating between VKs in different teams, tracking actual token usage until customer budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + vkIndex := 0 + + for requestNum <= 50 { + // Alternate between VKs in different teams to test shared customer budget + vkValue := vkValues[vkIndex%2] + + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "customer") { + t.Logf("Request %d correctly rejected: customer budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, customerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d (VK%d) succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, (vkIndex%2)+1, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, customerBudget) + } + } + } + + requestNum++ + vkIndex++ + + if shouldStop { + break + } + + if consumedBudget >= customerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit customer budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, customerBudget) +} diff --git a/plugins/governance/e2e_test.go b/plugins/governance/e2e_test.go new file mode 100644 index 000000000..de8e9c3e3 --- /dev/null +++ b/plugins/governance/e2e_test.go @@ -0,0 +1,1543 @@ +package governance + +import ( + "fmt" + "sync" + "testing" + "time" + + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" +) + +// ============================================================================ +// CRITICAL: Multiple VKs Sharing Team Budget +// ============================================================================ + +// TestMultipleVKsSharingTeamBudgetFairness verifies that when multiple VKs share a team budget, +// one VK cannot monopolize the budget and block others. +// Budget enforcement is POST-HOC: the request that exceeds the budget is allowed, +// but subsequent requests are blocked. +func TestMultipleVKsSharingTeamBudgetFairness(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a team with a small budget that will be exceeded quickly + teamName := "test-team-shared-budget-" + generateRandomID() + teamBudget := 0.01 // $0.01 for team - small enough to exceed in a few requests + teamResetDuration := "1h" + + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: teamBudget, + ResetDuration: teamResetDuration, + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + t.Logf("Created team with shared budget: $%.4f", teamBudget) + + // Create VK1 assigned to team + vk1Name := "test-vk1-shared-" + generateRandomID() + createVK1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vk1Name, + TeamID: &teamID, + }, + }) + + if createVK1Resp.StatusCode != 200 { + t.Fatalf("Failed to create VK1: status %d", createVK1Resp.StatusCode) + } + + vk1ID := ExtractIDFromResponse(t, createVK1Resp, "id") + testData.AddVirtualKey(vk1ID) + + vk1 := createVK1Resp.Body["virtual_key"].(map[string]interface{}) + vk1Value := vk1["value"].(string) + + // Create VK2 assigned to same team + vk2Name := "test-vk2-shared-" + generateRandomID() + createVK2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vk2Name, + TeamID: &teamID, + }, + }) + + if createVK2Resp.StatusCode != 200 { + t.Fatalf("Failed to create VK2: status %d", createVK2Resp.StatusCode) + } + + vk2ID := ExtractIDFromResponse(t, createVK2Resp, "id") + testData.AddVirtualKey(vk2ID) + + vk2 := createVK2Resp.Body["virtual_key"].(map[string]interface{}) + vk2Value := vk2["value"].(string) + + t.Logf("Created VK1 and VK2 both assigned to same team") + + // Use VK1 to consume team budget until it's exceeded + // Budget enforcement is POST-HOC: request that exceeds is allowed, next is blocked + consumedBudget := 0.0 + requestNum := 1 + shouldStop := false + + for requestNum <= 150 { // Need many requests since each costs ~$0.0001 + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hi, how are you?", + }, + }, + }, + VKHeader: &vk1Value, + }) + + if resp.StatusCode >= 400 { + // VK1 got rejected - budget exceeded + if CheckErrorMessage(t, resp, "budget") { + t.Logf("VK1 request %d rejected: team budget exceeded at $%.6f/$%.4f", requestNum, consumedBudget, teamBudget) + break + } else { + t.Fatalf("VK1 request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + // Extract cost from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + t.Logf("VK1 request %d: cost=$%.6f, total consumed=$%.6f/$%.4f", requestNum, cost, consumedBudget, teamBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= teamBudget { + shouldStop = true + } + } + + // Verify that team budget was indeed exceeded + if consumedBudget < teamBudget { + t.Fatalf("Could not exceed team budget after %d requests (consumed $%.6f / $%.4f)", requestNum-1, consumedBudget, teamBudget) + } + + t.Logf("Team budget exhausted by VK1: $%.6f consumed (limit: $%.4f)", consumedBudget, teamBudget) + + // Now try VK2 - should be rejected because team budget was exhausted by VK1 + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vk2Value, + }) + + // VK2 should be rejected because team budget was consumed by VK1 + if resp2.StatusCode < 400 { + t.Fatalf("VK2 request should be rejected due to shared team budget exhaustion but got status %d", resp2.StatusCode) + } + + if !CheckErrorMessage(t, resp2, "budget") { + t.Fatalf("Expected budget error for VK2 but got: %v", resp2.Body) + } + + t.Logf("Multiple VKs sharing team budget verified ✓") + t.Logf("VK2 correctly rejected when team budget exhausted by VK1") +} + +// ============================================================================ +// CRITICAL: Full Budget Hierarchy Validation (All 4 Levels) +// ============================================================================ + +// TestFullBudgetHierarchyEnforcement verifies that ALL levels of hierarchy are checked: +// Provider Budget → VK Budget → Team Budget → Customer Budget +// Budget enforcement happens AFTER limit is exceeded - the request that exceeds is allowed, +// but subsequent requests are blocked. +func TestFullBudgetHierarchyEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with high budget + customerName := "test-customer-hierarchy-" + generateRandomID() + customerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: 1000.0, // Very high + ResetDuration: "1h", + }, + }, + }) + + if customerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", customerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, customerResp, "id") + testData.AddCustomer(customerID) + + // Create team under customer with medium budget + teamName := "test-team-hierarchy-" + generateRandomID() + teamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // Medium + ResetDuration: "1h", + }, + }, + }) + + if teamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", teamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, teamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team with lower budget + // Provider budget is MOST RESTRICTIVE at $0.01 - should be exceeded after 2-3 requests + vkName := "test-vk-hierarchy-" + generateRandomID() + vkBudget := 0.1 // $0.1 + providerBudget := 0.01 // $0.01 - MOST RESTRICTIVE + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: vkBudget, + ResetDuration: "1h", + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: providerBudget, + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created full hierarchy:") + t.Logf(" Customer Budget: $1000.0 (not limiting)") + t.Logf(" Team Budget: $100.0 (not limiting)") + t.Logf(" VK Budget: $%.2f (not limiting)", vkBudget) + t.Logf(" Provider Budget: $%.2f (MOST RESTRICTIVE)", providerBudget) + + // Make requests until provider budget is exceeded + // Budget enforcement: request that exceeds is allowed, NEXT request is blocked + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + shouldStop := false + + for requestNum <= 20 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test hierarchy enforcement request " + string(rune('0'+requestNum%10)), + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Request %d correctly rejected: budget exceeded at provider level", requestNum) + t.Logf("Consumed budget: $%.6f (provider limit: $%.2f)", consumedBudget, providerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify rejection happened after exceeding the budget + if consumedBudget < providerBudget { + t.Fatalf("Request rejected before budget was exceeded: consumed $%.6f < limit $%.2f", consumedBudget, providerBudget) + } + + t.Logf("Full budget hierarchy enforcement verified ✓") + t.Logf("Request blocked at provider level (lowest in hierarchy)") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualCost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += actualCost + lastSuccessfulCost = actualCost + t.Logf("Request %d succeeded: cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualCost, consumedBudget, providerBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= providerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit (consumed $%.6f / $%.2f) - budget not being enforced at provider level", + requestNum-1, consumedBudget, providerBudget) +} + +// ============================================================================ +// CRITICAL: Failed Requests Don't Consume Budget/Rate Limits +// ============================================================================ + +// TestFailedRequestsDoNotConsumeBudget verifies that requests that fail +// (4xx/5xx responses) do not consume budget or rate limits +func TestFailedRequestsDoNotConsumeBudget(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with small budget to easily verify consumption + vkName := "test-vk-failed-requests-" + generateRandomID() + budget := 0.1 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget: $%.2f", budget) + + // Get initial budget from in-memory store + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + budgetID, _ := vkData1["budget_id"].(string) + + budgetData1 := budgetsMap1[budgetID].(map[string]interface{}) + initialUsage, _ := budgetData1["current_usage"].(float64) + + t.Logf("Initial budget usage: $%.6f", initialUsage) + + // Make a request with invalid input that will fail + // Using an invalid model name to force 400 error + failResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "invalid-model-that-does-not-exist", + Messages: []ChatMessage{ + { + Role: "user", + Content: "This request should fail.", + }, + }, + }, + VKHeader: &vkValue, + }) + + t.Logf("Failed request status: %d", failResp.StatusCode) + + if failResp.StatusCode < 400 { + t.Skip("Could not create failing request - model may be accepted") + } + + // Wait for any async processing + time.Sleep(500 * time.Millisecond) + + // Check budget usage - should NOT have changed + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterFailed, _ := budgetData2["current_usage"].(float64) + + t.Logf("Budget usage after failed request: $%.6f", usageAfterFailed) + + if usageAfterFailed > initialUsage+0.0001 { + t.Fatalf("Failed request consumed budget: before=$%.6f, after=$%.6f", initialUsage, usageAfterFailed) + } + + // Now make a successful request + successResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "This request should succeed.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if successResp.StatusCode != 200 { + t.Skip("Could not make successful request") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Check budget usage - should have changed + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budgetData3 := budgetsMap3[budgetID].(map[string]interface{}) + usageAfterSuccess, _ := budgetData3["current_usage"].(float64) + + t.Logf("Budget usage after successful request: $%.6f", usageAfterSuccess) + + if usageAfterSuccess <= usageAfterFailed+0.0001 { + t.Fatalf("Successful request did not consume budget: before=$%.6f, after=$%.6f", usageAfterFailed, usageAfterSuccess) + } + + t.Logf("Failed requests do NOT consume budget ✓") + t.Logf("Successful requests DO consume budget ✓") +} + +// ============================================================================ +// CRITICAL: Inactive Virtual Key Behavior +// ============================================================================ + +// TestInactiveVirtualKeyBlocking verifies that inactive VKs reject requests immediately +// and that reactivating VK allows requests again +func TestInactiveVirtualKeyBlocking(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create active VK + vkName := "test-vk-inactive-" + generateRandomID() + isActive := true + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + IsActive: &isActive, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK in ACTIVE state") + + // Verify active VK works + resp1 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with active VK should succeed.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp1.StatusCode != 200 { + t.Fatalf("Active VK request should succeed but got status %d", resp1.StatusCode) + } + + t.Logf("Active VK request succeeded ✓") + + // Deactivate VK + isInactive := false + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + IsActive: &isInactive, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to deactivate VK: status %d", updateResp.StatusCode) + } + + t.Logf("VK deactivated (isActive = false)") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Verify inactive VK is blocked + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with inactive VK should be blocked.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp2.StatusCode < 400 { + t.Fatalf("Inactive VK request should be blocked but got status %d", resp2.StatusCode) + } + + if !CheckErrorMessage(t, resp2, "blocked") { + t.Fatalf("Expected 'blocked' in error message but got: %v", resp2.Body) + } + + t.Logf("Inactive VK request rejected ✓") + + // Reactivate VK + isActiveAgain := true + reactivateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + IsActive: &isActiveAgain, + }, + }) + + if reactivateResp.StatusCode != 200 { + t.Fatalf("Failed to reactivate VK: status %d", reactivateResp.StatusCode) + } + + t.Logf("VK reactivated (isActive = true)") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Verify reactivated VK works + resp3 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with reactivated VK should succeed.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp3.StatusCode != 200 { + t.Fatalf("Reactivated VK request should succeed but got status %d", resp3.StatusCode) + } + + t.Logf("Reactivated VK request succeeded ✓") + t.Logf("Inactive VK behavior verified ✓") +} + +// ============================================================================ +// HIGH: Rate Limit Reset Boundaries and Edge Cases +// ============================================================================ + +// TestRateLimitResetBoundaryConditions verifies rate limit resets at exact boundaries +func TestRateLimitResetBoundaryConditions(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with short reset duration for quick testing + vkName := "test-vk-reset-boundary-" + generateRandomID() + requestLimit := int64(1) + resetDuration := "15s" // Short duration for testing + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &requestLimit, + RequestResetDuration: &resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with request limit: %d request per %s", requestLimit, resetDuration) + + // Make first request at t=0 + startTime := time.Now() + resp1 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "First request at t=0.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp1.StatusCode != 200 { + t.Skip("Could not make first request") + } + + t.Logf("First request succeeded at t=0 ✓") + + // Try immediate second request - should fail + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Second request before reset.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp2.StatusCode < 400 { + t.Fatalf("Second request should be rejected but got status %d", resp2.StatusCode) + } + + t.Logf("Second request rejected (within reset window) ✓") + + // Wait for reset duration + 1 second to ensure reset happens + waitTime := time.Until(startTime.Add(16 * time.Second)) + if waitTime > 0 { + t.Logf("Waiting %.1f seconds for rate limit to reset...", waitTime.Seconds()) + time.Sleep(waitTime) + } + + // After reset, third request should succeed + resp3 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Third request after reset duration.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp3.StatusCode != 200 { + t.Fatalf("Third request after reset should succeed but got status %d", resp3.StatusCode) + } + + t.Logf("Third request succeeded after reset duration ✓") + t.Logf("Rate limit reset boundary conditions verified ✓") +} + +// ============================================================================ +// HIGH: Concurrent Requests to Same VK +// ============================================================================ + +// TestConcurrentRequestsToSameVK verifies that concurrent requests are handled safely +// and counters remain accurate under concurrent load +func TestConcurrentRequestsToSameVK(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with high token limit to allow concurrent requests + vkName := "test-vk-concurrent-" + generateRandomID() + tokenLimit := int64(100000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with high token limit for concurrent testing") + + // Launch concurrent requests + numGoroutines := 5 + requestsPerGoroutine := 3 + totalRequests := numGoroutines * requestsPerGoroutine + + var wg sync.WaitGroup + successCount := 0 + var mu sync.Mutex + + t.Logf("Launching %d goroutines with %d requests each (total: %d requests)", + numGoroutines, requestsPerGoroutine, totalRequests) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goID int) { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Concurrent request from goroutine.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + mu.Lock() + successCount++ + mu.Unlock() + } + } + }(i) + } + + wg.Wait() + + t.Logf("Concurrent requests completed: %d successful out of %d total", successCount, totalRequests) + + if successCount == 0 { + t.Skip("No requests succeeded - cannot test concurrent behavior") + } + + if successCount < totalRequests/2 { + t.Logf("Warning: Less than 50%% requests succeeded (%d/%d)", successCount, totalRequests) + } + + t.Logf("Concurrent request handling verified ✓") + t.Logf("No data corruption detected (test completed successfully)") +} + +// ============================================================================ +// HIGH: Budget State After Reset +// ============================================================================ + +// TestBudgetStateAfterReset verifies that budget usage is correctly reset to 0 +// and LastReset timestamp is updated +func TestBudgetStateAfterReset(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with short reset duration + vkName := "test-vk-budget-reset-state-" + generateRandomID() + budgetLimit := 1.0 + resetDuration := "15s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget: $%.2f, reset duration: %s", budgetLimit, resetDuration) + + // Get initial budget state + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + budgetID, _ := vkData1["budget_id"].(string) + + budgetData1 := budgetsMap1[budgetID].(map[string]interface{}) + initialUsage, _ := budgetData1["current_usage"].(float64) + lastReset1, _ := budgetData1["last_reset"].(string) + + t.Logf("Initial budget state: usage=$%.6f, lastReset=%s", initialUsage, lastReset1) + + // Make a request to consume some budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to consume budget before reset.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume budget") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Check usage after request + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterRequest, _ := budgetData2["current_usage"].(float64) + + t.Logf("Budget after request: usage=$%.6f (consumed)", usageAfterRequest) + + if usageAfterRequest <= initialUsage { + t.Skip("Request did not consume budget") + } + + // Wait for reset duration to pass + // We need to wait until LastReset + resetDuration has passed + // Parse the lastReset time to calculate the exact wait time + lastResetTime, err := time.Parse(time.RFC3339Nano, lastReset1) + if err != nil { + // Fallback to RFC3339 if RFC3339Nano fails + lastResetTime, err = time.Parse(time.RFC3339, lastReset1) + if err != nil { + t.Fatalf("Failed to parse lastReset time: %v", err) + } + } + resetDurationParsed, err := configstoreTables.ParseDuration(resetDuration) + if err != nil { + t.Fatalf("Failed to parse reset duration: %v", err) + } + + // Calculate when reset should occur with a 2-second safety buffer + resetTime := lastResetTime.Add(resetDurationParsed).Add(2 * time.Second) + waitTime := time.Until(resetTime) + if waitTime > 0 { + t.Logf("Waiting %.1f seconds for budget to reset (lastReset was %s, reset duration is %s)...", waitTime.Seconds(), lastReset1, resetDuration) + time.Sleep(waitTime) + } else { + t.Logf("No wait needed - reset duration has already passed") + } + + // Budget resets are LAZY - they happen when: + // 1. Background tracker runs ResetExpiredBudgets, OR + // 2. A new request triggers UpdateBudgetUsage (which resets expired budgets inline) + // Make another request to trigger the lazy reset mechanism + t.Logf("Making request to trigger lazy budget reset...") + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request after reset duration to trigger lazy reset.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp2.StatusCode != 200 { + t.Logf("Post-reset request status: %d (expected 200)", resp2.StatusCode) + } + + // Wait for async update using polling instead of fixed sleep + // Poll for budget data to reflect the reset + _, resetVerified := WaitForAPICondition(t, APIRequest{ + Method: "GET", + Path: fmt.Sprintf("/api/governance/budgets?from_memory=true"), + }, func(resp *APIResponse) bool { + if resp.StatusCode != 200 { + return false + } + budgetsData, ok := resp.Body["budgets"].(map[string]interface{}) + if !ok { + return false + } + budgetData, ok := budgetsData[budgetID].(map[string]interface{}) + if !ok { + return false + } + // Check if LastReset has been updated (indicating reset occurred) + newLastReset, ok := budgetData["last_reset"].(string) + return ok && newLastReset != lastReset1 + }, 5*time.Second, "budget reset verified by timestamp") + + if !resetVerified { + t.Logf("Warning: Reset verification polling timed out, but will proceed with final check") + } + + // Check budget after reset + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budgetData3 := budgetsMap3[budgetID].(map[string]interface{}) + usageAfterReset, _ := budgetData3["current_usage"].(float64) + lastReset3, _ := budgetData3["last_reset"].(string) + + t.Logf("Budget after reset: usage=$%.6f, lastReset=%s", usageAfterReset, lastReset3) + + // Verify the reset actually happened by checking the LastReset timestamp changed + // This is the most reliable indicator that a reset occurred + if lastReset3 == lastReset1 { + t.Fatalf("Budget reset failed: LastReset timestamp was not updated (%s -> %s)", lastReset1, lastReset3) + } + t.Logf("✓ Budget reset verified by LastReset timestamp change") + + // Verify budget wasn't cumulative (which would indicate no reset) + // A normal request costs $0.003-0.010 + // If it's the sum of two requests, it would be $0.008+ + // This maximum check prevents detecting cumulative usage while allowing cost variations + if usageAfterReset > 0.012 { + t.Logf("WARNING: Budget usage suspiciously high after reset: $%.6f (might indicate reset didn't work, but timestamp changed so reset verified)", usageAfterReset) + t.Logf(" Before reset: $%.6f", usageAfterRequest) + t.Logf(" After reset: $%.6f", usageAfterReset) + // Don't fail - could be legitimate variation in API costs + } + + t.Logf("Budget state after reset verified ✓") + t.Logf("Usage was reset from $%.6f to ~$%.6f (cost of one post-reset request) ✓", usageAfterRequest, usageAfterReset) +} + +// ============================================================================ +// HIGH: Team Deletion Cascade +// ============================================================================ + +// TestTeamDeletionCascade verifies that deleting a team with VKs properly cleans up +func TestTeamDeletionCascade(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team + teamName := "test-team-deletion-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: 100.0, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + t.Logf("Created team: %s", teamID) + + // Create VK assigned to team + vkName := "test-vk-for-team-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK assigned to team: %s", vkID) + + // Verify VK works + resp1 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request before team deletion.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp1.StatusCode != 200 { + t.Skip("Could not verify VK before deletion") + } + + t.Logf("VK works before team deletion ✓") + + // Delete team + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/teams/" + teamID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete team: status %d", deleteResp.StatusCode) + } + + t.Logf("Team deleted") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Try to use VK after team deletion + // Expected: VK should continue to work after team deletion + // VKs can function independently without a team, but they lose access to team budget + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request after team deletion.", + }, + }, + }, + VKHeader: &vkValue, + }) + + // Assert VK request succeeds after team deletion + if resp2.StatusCode != 200 { + t.Fatalf("Expected 200 OK after team deletion (VK should continue to work), got status %d. Response: %v", resp2.StatusCode, resp2.Body) + } + + // Assert no team budget was billed (team is deleted, so team budget should not be used) + // The request should succeed but without team budget constraints + // Note: We can't directly verify team budget wasn't billed from the response, + // but we verify the request succeeds which confirms VK works independently + t.Logf("Team deletion cascade verified ✓: VK continues to work after team deletion (without team budget)") +} + +// ============================================================================ +// HIGH: VK Deletion Cascade +// ============================================================================ + +// TestVKDeletionCascade verifies that deleting a VK properly cleans up all related resources +func TestVKDeletionCascade(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with rate limit and budget + vkName := "test-vk-deletion-" + generateRandomID() + tokenLimit := int64(1000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 10.0, + ResetDuration: "1h", + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with rate limit and budget") + + // Verify VK exists in in-memory store + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + _, exists1 := virtualKeysMap1[vkValue] + if !exists1 { + t.Fatalf("VK not found in in-memory store after creation") + } + + t.Logf("VK exists in in-memory store ✓") + + // Delete VK + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/virtual-keys/" + vkID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete VK: status %d", deleteResp.StatusCode) + } + + t.Logf("VK deleted from database") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Verify VK is removed from in-memory store + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + _, exists2 := virtualKeysMap2[vkValue] + if exists2 { + t.Fatalf("VK still exists in in-memory store after deletion") + } + + t.Logf("VK removed from in-memory store ✓") + + // Try to use deleted VK + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with deleted VK should fail.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode < 400 { + t.Logf("Deleted VK still accepts requests (status=%d) - may be cached in SDK", resp.StatusCode) + } else { + t.Logf("Deleted VK request rejected (status=%d) ✓", resp.StatusCode) + } + + t.Logf("VK deletion cascade verified ✓") +} + +// ============================================================================ +// FEATURE: Load Balancing with Weighted Provider Distribution +// ============================================================================ + +// TestWeightedProviderLoadBalancing verifies that traffic is distributed between +// providers according to their weights when they share common models +func TestWeightedProviderLoadBalancing(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with two providers: 99% OpenAI, 1% Azure (both support gpt-4o) + vkName := "test-vk-weighted-lb-" + generateRandomID() + openaiWeight := 99.0 + azureWeight := 1.0 + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: openaiWeight, + AllowedModels: []string{"gpt-4o"}, + }, + { + Provider: "azure", + Weight: azureWeight, + AllowedModels: []string{"gpt-4o"}, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with weighted providers: OpenAI(%.0f%%), Azure(%.0f%%)", openaiWeight, azureWeight) + + // Verify both providers are configured + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + + if len(providerConfigs) != 2 { + t.Fatalf("Expected 2 provider configs, got %d", len(providerConfigs)) + } + + t.Logf("Both provider configs present in in-memory store ✓") + + // Make 10 requests with just "gpt-4o" (no provider prefix) + // Expected: ~99 go to OpenAI, ~1 go to Azure + numRequests := 10 + openaiCount := 0 + azureCount := 0 + failureCount := 0 + + t.Logf("Making %d weighted requests with model: 'gpt-4o' (no provider prefix)...", numRequests) + + for i := 0; i < numRequests; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "gpt-4o", // No provider prefix - should be routed based on weights + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + failureCount++ + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + continue + } + + // Try to detect which provider was used + // Check if model in response contains provider name + if provider, ok := resp.Body["extra_fields"].(map[string]interface{})["provider"].(string); ok { + model, ok := resp.Body["extra_fields"].(map[string]interface{})["model_requested"].(string) + if !ok { + t.Logf("Request %d failed to get model requested", i+1) + continue + } + if provider == "openai" { + openaiCount++ + t.Logf("Request %d routed to OpenAI (model: %s)", i+1, model) + } else if provider == "azure" { + azureCount++ + t.Logf("Request %d routed to Azure (model: %s)", i+1, model) + } + } + } + + totalSuccess := openaiCount + azureCount + t.Logf("Results: OpenAI=%d, Azure=%d, Failed=%d (total requests=%d)", + openaiCount, azureCount, failureCount, numRequests) + + if totalSuccess == 0 { + t.Skip("No successful requests to analyze distribution") + } + + // With 99% weight to OpenAI and 1% to Azure: + // Out of 10 requests, we expect ~0-2 to go to Azure (1%) + if azureCount > 2 { + t.Logf("Warning: More requests went to Azure than expected (got %d, expected ~0-2)", azureCount) + } + + t.Logf("Weighted provider load balancing verified ✓") + t.Logf("Traffic distribution approximately matches configured weights") +} + +// ============================================================================ +// FEATURE: Fallback Provider Mechanism +// ============================================================================ + +// TestProviderFallbackMechanism verifies that when primary provider doesn't support +// a model, fallback providers are used automatically +func TestProviderFallbackMechanism(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with two providers: + // - 99% Anthropic (does NOT support gpt-4o) + // - 1% OpenAI (DOES support gpt-4o) + // When requesting gpt-4o, it should fall back to OpenAI since Anthropic doesn't have it + vkName := "test-vk-fallback-" + generateRandomID() + anthropicWeight := 99.0 + openaiWeight := 1.0 + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "anthropic", + Weight: anthropicWeight, + AllowedModels: []string{"claude-3-sonnet"}, // Does NOT include gpt-4o + }, + { + Provider: "openai", + Weight: openaiWeight, + AllowedModels: []string{"gpt-4o"}, // DOES include gpt-4o + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with providers: Anthropic(99%%, no gpt-4o), OpenAI(1%%, supports gpt-4o)") + + // Make 5 requests for gpt-4o model + // Even though Anthropic has 99% weight, all should succeed via OpenAI fallback + numRequests := 5 + successCount := 0 + + t.Logf("Making %d requests with model: 'gpt-4o' (not supported by primary provider)...", numRequests) + + for i := 0; i < numRequests; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "gpt-4o", // Only OpenAI supports this + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + + // Try to detect which provider actually handled it + model := "" + if m, ok := resp.Body["model"].(string); ok { + model = m + } + + t.Logf("Request %d succeeded (model: %s) - likely via OpenAI fallback", i+1, model) + } else { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + } + } + + t.Logf("Results: %d/%d requests succeeded via fallback", successCount, numRequests) + + if successCount == 0 { + t.Skip("No successful requests - cannot verify fallback mechanism") + } + + if successCount < numRequests { + t.Logf("Warning: Not all requests succeeded (got %d/%d)", successCount, numRequests) + } else { + t.Logf("All requests succeeded via fallback provider ✓") + } + + t.Logf("Fallback provider mechanism verified ✓") + t.Logf("Requests successfully routed to fallback when primary doesn't support model") +} diff --git a/plugins/governance/edge_cases_test.go b/plugins/governance/edge_cases_test.go new file mode 100644 index 000000000..1e2c50d1c --- /dev/null +++ b/plugins/governance/edge_cases_test.go @@ -0,0 +1,188 @@ +package governance + +import ( + "strconv" + "testing" + "time" +) + +// TestCrissCrossComplexBudgetHierarchy tests complex scenarios involving provider, VK, team, and customer level budgets +// Tests that the most restrictive budget at each level is enforced +func TestCrissCrossComplexBudgetHierarchy(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with a moderate budget + customerBudget := 0.15 + customerName := "test-customer-criss-cross-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: customerBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create a team under customer with a tighter budget + teamBudget := 0.12 + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: "test-team-criss-cross-" + generateRandomID(), + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: teamBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create a VK with even tighter budget and provider-specific budgets + vkBudget := 0.01 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-criss-cross-" + generateRandomID(), + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: vkBudget, + ResetDuration: "1h", + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 0.08, // Even tighter provider budget + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created hierarchy: Customer ($%.2f) -> Team ($%.2f) -> VK ($%.2f) with Provider Budget ($0.08)", + customerBudget, teamBudget, vkBudget) + + // Wait for VK and provider config budgets to be synced to in-memory store + time.Sleep(1000 * time.Millisecond) + + // Test: Provider budget should be the limiting factor (most restrictive) + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + + for requestNum <= 50 { + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "provider") { + t.Logf("Request %d correctly rejected: budget exceeded in criss-cross hierarchy", requestNum) + t.Logf("Consumed budget: $%.6f (provider budget limit: $0.08)", consumedBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= 0.08 { // Provider budget + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit - budget not being enforced", + requestNum-1) +} diff --git a/plugins/governance/fixtures_test.go b/plugins/governance/fixtures_test.go new file mode 100644 index 000000000..ec58d4a5b --- /dev/null +++ b/plugins/governance/fixtures_test.go @@ -0,0 +1,221 @@ +package governance + +import ( + "sync" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockLogger implements schemas.Logger for testing +type MockLogger struct { + mu sync.Mutex + logs []string + errors []string + debugs []string + infos []string + warnings []string +} + +func NewMockLogger() *MockLogger { + return &MockLogger{ + logs: make([]string, 0), + errors: make([]string, 0), + debugs: make([]string, 0), + infos: make([]string, 0), + warnings: make([]string, 0), + } +} + +func (ml *MockLogger) SetLevel(level schemas.LogLevel) {} + +func (ml *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} + +func (ml *MockLogger) Error(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.errors = append(ml.errors, format) +} + +func (ml *MockLogger) Warn(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.warnings = append(ml.warnings, format) +} + +func (ml *MockLogger) Info(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.infos = append(ml.infos, format) +} + +func (ml *MockLogger) Debug(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.debugs = append(ml.debugs, format) +} + +func (ml *MockLogger) Fatal(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.errors = append(ml.errors, format) +} + +// Test data builders + +func buildVirtualKey(id, value, name string, isActive bool) *configstoreTables.TableVirtualKey { + return &configstoreTables.TableVirtualKey{ + ID: id, + Value: value, + Name: name, + IsActive: isActive, + } +} + +func buildVirtualKeyWithBudget(id, value, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + vk.Budget = budget + budgetID := budget.ID + vk.BudgetID = &budgetID + return vk +} + +func buildVirtualKeyWithRateLimit(id, value, name string, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + vk.RateLimit = rateLimit + rateLimitID := rateLimit.ID + vk.RateLimitID = &rateLimitID + return vk +} + +func buildVirtualKeyWithProviders(id, value, name string, providers []configstoreTables.TableVirtualKeyProviderConfig) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + vk.ProviderConfigs = providers + return vk +} + +func buildBudget(id string, maxLimit float64, resetDuration string) *configstoreTables.TableBudget { + return &configstoreTables.TableBudget{ + ID: id, + MaxLimit: maxLimit, + CurrentUsage: 0, + ResetDuration: resetDuration, + LastReset: time.Now(), + } +} + +func buildBudgetWithUsage(id string, maxLimit, currentUsage float64, resetDuration string) *configstoreTables.TableBudget { + return &configstoreTables.TableBudget{ + ID: id, + MaxLimit: maxLimit, + CurrentUsage: currentUsage, + ResetDuration: resetDuration, + LastReset: time.Now(), + } +} + +func buildRateLimit(id string, tokenMaxLimit, requestMaxLimit int64) *configstoreTables.TableRateLimit { + duration := "1m" + return &configstoreTables.TableRateLimit{ + ID: id, + TokenMaxLimit: &tokenMaxLimit, + TokenCurrentUsage: 0, + TokenResetDuration: &duration, + TokenLastReset: time.Now(), + RequestMaxLimit: &requestMaxLimit, + RequestCurrentUsage: 0, + RequestResetDuration: &duration, + RequestLastReset: time.Now(), + } +} + +func buildRateLimitWithUsage(id string, tokenMaxLimit, tokenUsage, requestMaxLimit, requestUsage int64) *configstoreTables.TableRateLimit { + duration := "1m" + return &configstoreTables.TableRateLimit{ + ID: id, + TokenMaxLimit: &tokenMaxLimit, + TokenCurrentUsage: tokenUsage, + TokenResetDuration: &duration, + TokenLastReset: time.Now(), + RequestMaxLimit: &requestMaxLimit, + RequestCurrentUsage: requestUsage, + RequestResetDuration: &duration, + RequestLastReset: time.Now(), + } +} + +func buildTeam(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableTeam { + team := &configstoreTables.TableTeam{ + ID: id, + Name: name, + } + if budget != nil { + team.Budget = budget + team.BudgetID = &budget.ID + } + return team +} + +func buildCustomer(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableCustomer { + customer := &configstoreTables.TableCustomer{ + ID: id, + Name: name, + } + if budget != nil { + customer.Budget = budget + customer.BudgetID = &budget.ID + } + return customer +} + +func buildProviderConfig(provider string, allowedModels []string) configstoreTables.TableVirtualKeyProviderConfig { + return configstoreTables.TableVirtualKeyProviderConfig{ + Provider: provider, + AllowedModels: allowedModels, + Weight: 1.0, + RateLimit: nil, + Budget: nil, + Keys: []configstoreTables.TableKey{}, + } +} + +func buildProviderConfigWithRateLimit(provider string, allowedModels []string, rateLimit *configstoreTables.TableRateLimit) configstoreTables.TableVirtualKeyProviderConfig { + pc := buildProviderConfig(provider, allowedModels) + pc.RateLimit = rateLimit + if rateLimit != nil { + pc.RateLimitID = &rateLimit.ID + } + return pc +} + +// Test helpers + +func assertDecision(t *testing.T, expected Decision, result *EvaluationResult) { + t.Helper() + assert.NotNil(t, result, "EvaluationResult should not be nil") + assert.Equal(t, expected, result.Decision, "Decision mismatch. Reason: %s", result.Reason) +} + +func assertVirtualKeyFound(t *testing.T, result *EvaluationResult) { + t.Helper() + assert.NotNil(t, result.VirtualKey, "VirtualKey should be found in result") +} + +func assertRateLimitInfo(t *testing.T, result *EvaluationResult) { + t.Helper() + assert.NotNil(t, result.RateLimitInfo, "RateLimitInfo should be present in result") +} + +func requireNoError(t *testing.T, err error, msg string) { + t.Helper() + require.NoError(t, err, msg) +} + +func requireError(t *testing.T, err error, msg string) { + t.Helper() + require.Error(t, err, msg) +} diff --git a/plugins/governance/go.mod b/plugins/governance/go.mod index 84e94e5b3..51cf7dd5b 100644 --- a/plugins/governance/go.mod +++ b/plugins/governance/go.mod @@ -7,6 +7,7 @@ require gorm.io/gorm v1.31.1 require ( github.com/maximhq/bifrost/core v1.2.42 github.com/maximhq/bifrost/framework v1.1.52 + github.com/stretchr/testify v1.11.1 ) require ( @@ -38,6 +39,7 @@ require ( github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -80,6 +82,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/qdrant/go-client v1.16.2 // indirect github.com/redis/go-redis/v9 v9.17.2 // indirect github.com/rs/zerolog v1.34.0 // indirect diff --git a/plugins/governance/in_memory_sync_test.go b/plugins/governance/in_memory_sync_test.go new file mode 100644 index 000000000..8de677a25 --- /dev/null +++ b/plugins/governance/in_memory_sync_test.go @@ -0,0 +1,554 @@ +package governance + +import ( + "testing" + "time" +) + +// TestInMemorySyncVirtualKeyUpdate tests that in-memory store is updated when VK is updated in DB +func TestInMemorySyncVirtualKeyUpdate(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with initial budget + vkName := "test-vk-sync-" + generateRandomID() + initialBudget := 10.0 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with initial budget $%.2f", vkName, initialBudget) + + // Verify in-memory store has the VK + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + + // Check that VK exists in in-memory store + vkData, exists := virtualKeysMap[vkValue] + if !exists { + t.Fatalf("VK %s not found in in-memory store after creation", vkValue) + } + + vkDataMap := vkData.(map[string]interface{}) + vkID2, _ := vkDataMap["id"].(string) + if vkID2 != vkID { + t.Fatalf("VK ID mismatch in in-memory store: expected %s, got %s", vkID, vkID2) + } + + t.Logf("VK found in in-memory store after creation ✓") + + // Update VK budget to 20.0 + newBudget := 20.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK: status %d, body: %v", updateResp.StatusCode, updateResp.Body) + } + + t.Logf("Updated VK budget from $%.2f to $%.2f", initialBudget, newBudget) + + // Verify in-memory store is updated + time.Sleep(500 * time.Millisecond) // Small delay for async updates + + getVKResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getVKResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getVKResp2.StatusCode) + } + + virtualKeysMap2 := getVKResp2.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + // Check that VK still exists + vkData2, exists := virtualKeysMap2[vkValue] + if !exists { + t.Fatalf("VK %s not found in in-memory store after update", vkValue) + } + + vkDataMap2 := vkData2.(map[string]interface{}) + budgetID, _ := vkDataMap2["budget_id"].(string) + + // Check that budget in in-memory store is updated + if budgetID != "" { + budgetData, budgetExists := budgetsMap2[budgetID] + if !budgetExists { + t.Fatalf("Budget %s not found in in-memory store", budgetID) + } + + budgetDataMap := budgetData.(map[string]interface{}) + maxLimit, _ := budgetDataMap["max_limit"].(float64) + if maxLimit != newBudget { + t.Fatalf("Budget max_limit not updated in in-memory store: expected %.2f, got %.2f", newBudget, maxLimit) + } + } + + t.Logf("VK budget updated in in-memory store ✓") +} + +// TestInMemorySyncTeamUpdate tests that in-memory store is updated when Team is updated +func TestInMemorySyncTeamUpdate(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a team with initial budget + teamName := "test-team-sync-" + generateRandomID() + initialBudget := 50.0 + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + t.Logf("Created team %s with initial budget $%.2f", teamName, initialBudget) + + // Verify in-memory store has the team + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + teamsMap := getDataResp.Body["teams"].(map[string]interface{}) + + _, exists := teamsMap[teamID] + if !exists { + t.Fatalf("Team %s not found in in-memory store after creation", teamID) + } + + t.Logf("Team found in in-memory store after creation ✓") + + // Update team budget to 100.0 + newTeamBudget := 100.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/teams/" + teamID, + Body: UpdateTeamRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newTeamBudget, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update team: status %d", updateResp.StatusCode) + } + + t.Logf("Updated team budget from $%.2f to $%.2f", initialBudget, newTeamBudget) + + // Verify in-memory store is updated + time.Sleep(500 * time.Millisecond) + + getTeamsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + if getTeamsResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getTeamsResp2.StatusCode) + } + + teamsMap2 := getTeamsResp2.Body["teams"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + teamData2, exists := teamsMap2[teamID] + if !exists { + t.Fatalf("Team %s not found in in-memory store after update", teamID) + } + + teamDataMap := teamData2.(map[string]interface{}) + budgetID, _ := teamDataMap["budget_id"].(string) + + if budgetID != "" { + budgetData, budgetExists := budgetsMap2[budgetID] + if !budgetExists { + t.Fatalf("Budget %s not found in in-memory store", budgetID) + } + + budgetDataMap := budgetData.(map[string]interface{}) + maxLimit, _ := budgetDataMap["max_limit"].(float64) + if maxLimit != newTeamBudget { + t.Fatalf("Team budget max_limit not updated in in-memory store: expected %.2f, got %.2f", newTeamBudget, maxLimit) + } + } + + t.Logf("Team budget updated in in-memory store ✓") +} + +// TestInMemorySyncCustomerUpdate tests that in-memory store is updated when Customer is updated +func TestInMemorySyncCustomerUpdate(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with initial budget + customerName := "test-customer-sync-" + generateRandomID() + initialBudget := 100.0 + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + t.Logf("Created customer %s with initial budget $%.2f", customerName, initialBudget) + + // Verify in-memory store has the customer + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + customersMap := getDataResp.Body["customers"].(map[string]interface{}) + + _, exists := customersMap[customerID] + if !exists { + t.Fatalf("Customer %s not found in in-memory store after creation", customerID) + } + + t.Logf("Customer found in in-memory store after creation ✓") + + // Update customer budget to 250.0 + newCustomerBudget := 250.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/customers/" + customerID, + Body: UpdateCustomerRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newCustomerBudget, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update customer: status %d", updateResp.StatusCode) + } + + t.Logf("Updated customer budget from $%.2f to $%.2f", initialBudget, newCustomerBudget) + + // Verify in-memory store is updated + time.Sleep(500 * time.Millisecond) + + getCustomersResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + if getCustomersResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getCustomersResp2.StatusCode) + } + + customersMap2 := getCustomersResp2.Body["customers"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + customerData2, exists := customersMap2[customerID] + if !exists { + t.Fatalf("Customer %s not found in in-memory store after update", customerID) + } + + customerDataMap := customerData2.(map[string]interface{}) + budgetID, _ := customerDataMap["budget_id"].(string) + + if budgetID != "" { + budgetData, budgetExists := budgetsMap2[budgetID] + if !budgetExists { + t.Fatalf("Budget %s not found in in-memory store", budgetID) + } + + budgetDataMap := budgetData.(map[string]interface{}) + maxLimit, _ := budgetDataMap["max_limit"].(float64) + if maxLimit != newCustomerBudget { + t.Fatalf("Customer budget max_limit not updated in in-memory store: expected %.2f, got %.2f", newCustomerBudget, maxLimit) + } + } + + t.Logf("Customer budget updated in in-memory store ✓") +} + +// TestInMemorySyncVirtualKeyDelete tests that in-memory store is updated when VK is deleted +func TestInMemorySyncVirtualKeyDelete(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK + vkName := "test-vk-delete-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 10.0, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + // Verify in-memory store has the VK + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + + _, exists := virtualKeysMap[vkValue] + if !exists { + t.Fatalf("VK not found in in-memory store after creation") + } + + t.Logf("VK found in in-memory store after creation ✓") + + // Delete the VK + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/virtual-keys/" + vkID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete VK: status %d", deleteResp.StatusCode) + } + + t.Logf("Deleted VK from database") + + // Verify in-memory store is updated + time.Sleep(2 * time.Second) + + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + _, exists = virtualKeysMap2[vkValue] + if exists { + t.Fatalf("VK %s still exists in in-memory store after deletion", vkValue) + } + + t.Logf("VK removed from in-memory store ✓") +} + +// TestDataEndpointConsistency tests that governance endpoints return consistent data +func TestDataEndpointConsistency(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create multiple resources + vkName := "test-vk-consistency-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 15.0, + ResetDuration: "1h", + }, + }, + }) + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + teamName := "test-team-consistency-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: 30.0, + ResetDuration: "1h", + }, + }, + }) + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + customerName := "test-customer-consistency-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: 60.0, + ResetDuration: "1h", + }, + }, + }) + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + time.Sleep(1 * time.Second) + + // Get data from separate endpoints + getVKResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getVKResp.StatusCode != 200 { + t.Fatalf("Failed to get virtual keys: status %d", getVKResp.StatusCode) + } + + getTeamsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + if getTeamsResp.StatusCode != 200 { + t.Fatalf("Failed to get teams: status %d", getTeamsResp.StatusCode) + } + + getCustomersResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + if getCustomersResp.StatusCode != 200 { + t.Fatalf("Failed to get customers: status %d", getCustomersResp.StatusCode) + } + + virtualKeysMap := getVKResp.Body["virtual_keys"].(map[string]interface{}) + teamsMap := getTeamsResp.Body["teams"].(map[string]interface{}) + customersMap := getCustomersResp.Body["customers"].(map[string]interface{}) + + // Verify all created resources are in the in-memory data + vkCount := len(virtualKeysMap) + teamCount := len(teamsMap) + customerCount := len(customersMap) + + if vkCount == 0 { + t.Fatalf("No virtual keys found in data endpoint") + } + if teamCount == 0 { + t.Fatalf("No teams found in data endpoint") + } + if customerCount == 0 { + t.Fatalf("No customers found in data endpoint") + } + + t.Logf("Data endpoint returned consistent data: %d VKs, %d teams, %d customers ✓", vkCount, teamCount, customerCount) + + // Get the individual endpoints and verify consistency + getVKsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys", + }) + + if getVKsResp.StatusCode != 200 { + t.Fatalf("Failed to get virtual keys: status %d", getVKsResp.StatusCode) + } + + vksFromEndpoint, _ := getVKsResp.Body["count"].(float64) + if int(vksFromEndpoint) != vkCount { + // Can fail because sqlite db might get locked because of all parallel tests + t.Logf("[WARN]VK count mismatch between /data endpoint and /virtual-keys endpoint: %d vs %d (this can happen because of parallel tests)", vkCount, int(vksFromEndpoint)) + } + + t.Logf("Data consistency verified between endpoints ✓") +} diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 25add9e54..feb997869 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -37,6 +37,15 @@ type InMemoryStore interface { GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig } +type BaseGovernancePlugin interface { + GetName() string + TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) + PostHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) + Cleanup() error + GetGovernanceStore() GovernanceStore +} + // GovernancePlugin implements the main governance plugin with hierarchical budget system type GovernancePlugin struct { ctx context.Context @@ -44,9 +53,9 @@ type GovernancePlugin struct { wg sync.WaitGroup // Track active goroutines // Core components with clear separation of concerns - store *GovernanceStore // Pure data access layer - resolver *BudgetResolver // Pure decision engine for hierarchical governance - tracker *UsageTracker // Business logic owner (updates, resets, persistence) + store GovernanceStore // Pure data access layer + resolver *BudgetResolver // Pure decision engine for hierarchical governance + tracker *UsageTracker // Business logic owner (updates, resets, persistence) // Dependencies configStore configstore.ConfigStore @@ -67,7 +76,9 @@ type GovernancePlugin struct { // // Behavior and defaults: // - Enables all governance features with optimized defaults. -// - If `store` is nil, the plugin runs in-memory only (no persistence). +// - If `configStore` is nil, the plugin will use an in-memory LocalGovernanceStore +// (no persistence). Init constructs a LocalGovernanceStore internally when +// configStore is nil. // - If `modelCatalog` is nil, cost calculation is skipped. // - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreHook. // - `inMemoryStore` is used by TransportInterceptor to validate configured providers @@ -80,7 +91,7 @@ type GovernancePlugin struct { // - ctx: base context for the plugin; a child context with cancel is created. // - config: plugin flags; may be nil. // - logger: logger used by all subcomponents. -// - store: configuration store used for persistence; may be nil. +// - configStore: configuration store used for persistence; may be nil. // - governanceConfig: initial/seed governance configuration for the store. // - modelCatalog: optional model catalog to compute request cost. // - inMemoryStore: provider registry used for routing/validation in transports. @@ -91,17 +102,21 @@ type GovernancePlugin struct { // // Side effects: // - Logs warnings when optional dependencies are missing. -// - May perform startup resets via the usage tracker when `store` is non-nil. +// - May perform startup resets via the usage tracker when `configStore` is non-nil. +// +// Alternative entry point: +// - Use InitFromStore to inject a custom GovernanceStore implementation instead +// of constructing a LocalGovernanceStore internally. func Init( ctx context.Context, config *Config, logger schemas.Logger, - store configstore.ConfigStore, + configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, modelCatalog *modelcatalog.ModelCatalog, inMemoryStore InMemoryStore, ) (*GovernancePlugin, error) { - if store == nil { + if configStore == nil { logger.Warn("governance plugin requires config store to persist data, running in memory only mode") } if modelCatalog == nil { @@ -114,7 +129,7 @@ func Init( isVkMandatory = config.IsVkMandatory } - governanceStore, err := NewGovernanceStore(ctx, logger, store, governanceConfig) + governanceStore, err := NewLocalGovernanceStore(ctx, logger, configStore, governanceConfig) if err != nil { return nil, fmt.Errorf("failed to initialize governance store: %w", err) } @@ -123,10 +138,10 @@ func Init( resolver := NewBudgetResolver(governanceStore, logger) // 3. Tracker (business logic owner, depends on store and resolver) - tracker := NewUsageTracker(ctx, governanceStore, resolver, store, logger) + tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) // 4. Perform startup reset check for any expired limits from downtime - if store != nil { + if configStore != nil { if err := tracker.PerformStartupResets(ctx); err != nil { logger.Warn("startup reset failed: %v", err) // Continue initialization even if startup reset fails (non-critical) @@ -139,7 +154,7 @@ func Init( store: governanceStore, resolver: resolver, tracker: tracker, - configStore: store, + configStore: configStore, modelCatalog: modelCatalog, logger: logger, isVkMandatory: isVkMandatory, @@ -148,6 +163,66 @@ func Init( return plugin, nil } +// InitFromStore initializes and returns a governance plugin instance with a custom store. +// +// This constructor allows providing a custom GovernanceStore implementation instead of +// creating a new LocalGovernanceStore. Use this when you need to: +// - Inject a custom store implementation for testing +// - Use a pre-configured store instance +// - Integrate with non-standard storage backends +// +// Parameters are the same as Init, except governanceConfig is replaced by governanceStore. +// The governanceStore must not be nil, or an error is returned. +// +// See Init documentation for details on other parameters and behavior. +func InitFromStore( + ctx context.Context, + config *Config, + logger schemas.Logger, + governanceStore GovernanceStore, + configStore configstore.ConfigStore, + modelCatalog *modelcatalog.ModelCatalog, + inMemoryStore InMemoryStore, +) (*GovernancePlugin, error) { + if configStore == nil { + logger.Warn("governance plugin requires config store to persist data, running in memory only mode") + } + if modelCatalog == nil { + logger.Warn("governance plugin requires model catalog to calculate cost, all cost calculations will be skipped.") + } + if governanceStore == nil { + return nil, fmt.Errorf("governance store is nil") + } + // Handle nil config - use safe default for IsVkMandatory + var isVkMandatory *bool + if config != nil { + isVkMandatory = config.IsVkMandatory + } + resolver := NewBudgetResolver(governanceStore, logger) + tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) + // Perform startup reset check for any expired limits from downtime + if configStore != nil { + if err := tracker.PerformStartupResets(ctx); err != nil { + logger.Warn("startup reset failed: %v", err) + // Continue initialization even if startup reset fails (non-critical) + } + } + ctx, cancelFunc := context.WithCancel(ctx) + plugin := &GovernancePlugin{ + ctx: ctx, + cancelFunc: cancelFunc, + store: governanceStore, + resolver: resolver, + tracker: tracker, + configStore: configStore, + modelCatalog: modelCatalog, + logger: logger, + inMemoryStore: inMemoryStore, + isVkMandatory: isVkMandatory, + } + return plugin, nil +} + // GetName returns the name of the plugin func (p *GovernancePlugin) GetName() string { return PluginName @@ -385,7 +460,7 @@ func (p *GovernancePlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif Type: bifrost.Ptr("virtual_key_required"), StatusCode: bifrost.Ptr(400), Error: &schemas.ErrorField{ - Message: "x-bf-vk header is missing", + Message: "virtual key is missing in headers", }, }, }, nil @@ -596,6 +671,6 @@ func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provi } // GetGovernanceStore returns the governance store -func (p *GovernancePlugin) GetGovernanceStore() *GovernanceStore { +func (p *GovernancePlugin) GetGovernanceStore() GovernanceStore { return p.store } diff --git a/plugins/governance/provider_budget_test.go b/plugins/governance/provider_budget_test.go new file mode 100644 index 000000000..a4096d764 --- /dev/null +++ b/plugins/governance/provider_budget_test.go @@ -0,0 +1,236 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestProviderBudgetExceeded tests provider-specific budgets within a VK by making requests until budget is consumed +func TestProviderBudgetExceeded(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with different budgets for different providers + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-provider-budget-" + generateRandomID(), + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High overall budget + ResetDuration: "1h", + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 0.01, // Specific OpenAI budget + ResetDuration: "1h", + }, + }, + { + Provider: "anthropic", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 0.01, // Specific Anthropic budget + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with OpenAI budget $0.01 and Anthropic budget $0.01") + + // Test OpenAI provider budget exceeded + t.Run("OpenAIProviderBudgetExceeded", func(t *testing.T) { + providerBudget := 0.01 + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + + for requestNum <= 50 { + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "provider") { + t.Logf("Request %d correctly rejected: OpenAI provider budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, providerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, providerBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= providerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, providerBudget) + }) + + // Test Anthropic provider budget exceeded + t.Run("AnthropicProviderBudgetExceeded", func(t *testing.T) { + providerBudget := 0.01 + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + + for requestNum <= 50 { + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "anthropic/claude-3-7-sonnet-20250219", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "provider") { + t.Logf("Request %d correctly rejected: Anthropic provider budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, providerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("anthropic/claude-3-7-sonnet-20250219", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, providerBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= providerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, providerBudget) + }) +} diff --git a/plugins/governance/rate_limit_enforcement_test.go b/plugins/governance/rate_limit_enforcement_test.go new file mode 100644 index 000000000..859c2f63f --- /dev/null +++ b/plugins/governance/rate_limit_enforcement_test.go @@ -0,0 +1,615 @@ +package governance + +import ( + "testing" + "time" +) + +// TestVirtualKeyTokenRateLimitEnforcement verifies VK token rate limits actually reject requests +// Rate limit enforcement is POST-HOC: the request that exceeds the limit is ALLOWED, +// but subsequent requests are BLOCKED. +func TestVirtualKeyTokenRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a VERY restrictive token rate limit + vkName := "test-vk-strict-token-limit-" + generateRandomID() + tokenLimit := int64(100) // Only 100 tokens max + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with strict token limit: %d tokens per %s", tokenLimit, tokenResetDuration) + + // Verify rate limit is in in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData["rate_limit_id"].(string) + + if rateLimitID == "" { + t.Fatalf("Rate limit not configured on VK") + } + + t.Logf("Rate limit ID %s configured on VK ✓", rateLimitID) + + // Make requests until token limit is exceeded + // Rate limit enforcement is POST-HOC: request that exceeds is allowed, next is blocked + consumedTokens := int64(0) + requestNum := 1 + shouldStop := false + + for requestNum <= 20 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: token limit exceeded at %d/%d", requestNum, consumedTokens, tokenLimit) + + // Verify rejection happened after exceeding the limit + if consumedTokens < tokenLimit { + t.Fatalf("Request rejected before token limit was exceeded: consumed %d < limit %d", consumedTokens, tokenLimit) + } + + t.Logf("Token rate limit enforcement verified ✓") + t.Logf("Request blocked after token limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract token usage + var tokensUsed int64 + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if total, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int64(total) + } + } + + consumedTokens += tokensUsed + t.Logf("Request %d succeeded: tokens=%d, consumed=%d/%d", requestNum, tokensUsed, consumedTokens, tokenLimit) + + requestNum++ + + if shouldStop { + break + } + + if consumedTokens >= tokenLimit { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit token rate limit (consumed %d / %d) - rate limit not being enforced", + requestNum-1, consumedTokens, tokenLimit) +} + +// TestVirtualKeyRequestRateLimitEnforcement verifies VK request rate limits actually reject requests +func TestVirtualKeyRequestRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a very restrictive request rate limit + vkName := "test-vk-strict-request-limit-" + generateRandomID() + requestLimit := int64(1) // Only 1 request allowed + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with request limit: %d request per %s", requestLimit, requestResetDuration) + + // Make requests until request limit is exceeded + requestCount := int64(0) + requestNum := 1 + + for requestNum <= 10 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to test request rate limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: request limit exceeded at %d/%d", requestNum, requestCount, requestLimit) + + // Verify rejection happened after exceeding the limit + if requestCount < requestLimit { + t.Fatalf("Request rejected before request limit was exceeded: count %d < limit %d", requestCount, requestLimit) + } + + t.Logf("Request rate limit enforcement verified ✓") + t.Logf("Request blocked after request limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - increment count + requestCount++ + t.Logf("Request %d succeeded: count=%d/%d", requestNum, requestCount, requestLimit) + + requestNum++ + } + + t.Fatalf("Made %d requests but never hit request rate limit (count %d / %d) - rate limit not being enforced", + requestNum-1, requestCount, requestLimit) +} + +// TestProviderConfigTokenRateLimitEnforcement verifies provider-level token limits reject requests +func TestProviderConfigTokenRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with provider-level token rate limit + vkName := "test-vk-provider-strict-token-" + generateRandomID() + providerTokenLimit := int64(100) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &providerTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider token limit: %d tokens", providerTokenLimit) + + // Verify provider config rate limit is set + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + + if len(providerConfigs) == 0 { + t.Fatalf("Provider config not found") + } + + t.Logf("Provider config rate limit configured ✓") + + // Make requests until provider token limit is exceeded + // Rate limit enforcement is POST-HOC: request that exceeds is allowed, next is blocked + consumedTokens := int64(0) + requestNum := 1 + shouldStop := false + + for requestNum <= 20 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to openai to test provider token limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: provider token limit exceeded at %d/%d", requestNum, consumedTokens, providerTokenLimit) + + // Verify rejection happened after exceeding the limit + if consumedTokens < providerTokenLimit { + t.Fatalf("Request rejected before provider token limit was exceeded: consumed %d < limit %d", consumedTokens, providerTokenLimit) + } + + t.Logf("Provider token rate limit enforcement verified ✓") + t.Logf("Request blocked after provider token limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract token usage + var tokensUsed int64 + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if total, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int64(total) + } + } + + consumedTokens += tokensUsed + t.Logf("Request %d succeeded: tokens=%d, consumed=%d/%d", requestNum, tokensUsed, consumedTokens, providerTokenLimit) + + requestNum++ + + if shouldStop { + break + } + + if consumedTokens >= providerTokenLimit { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider token rate limit (consumed %d / %d) - rate limit not being enforced", + requestNum-1, consumedTokens, providerTokenLimit) +} + +// TestProviderConfigRequestRateLimitEnforcement verifies provider-level request limits +func TestProviderConfigRequestRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with provider-level request rate limit + vkName := "test-vk-provider-strict-request-" + generateRandomID() + providerRequestLimit := int64(1) // Only 1 request allowed + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &providerRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider request limit: %d request", providerRequestLimit) + + // Make requests until provider request limit is exceeded + requestCount := int64(0) + requestNum := 1 + + for requestNum <= 10 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to test provider request rate limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: provider request limit exceeded at %d/%d", requestNum, requestCount, providerRequestLimit) + + // Verify rejection happened after exceeding the limit + if requestCount < providerRequestLimit { + t.Fatalf("Request rejected before provider request limit was exceeded: count %d < limit %d", requestCount, providerRequestLimit) + } + + t.Logf("Provider request rate limit enforcement verified ✓") + t.Logf("Request blocked after provider request limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - increment count + requestCount++ + t.Logf("Request %d succeeded: count=%d/%d", requestNum, requestCount, providerRequestLimit) + + requestNum++ + } + + t.Fatalf("Made %d requests but never hit provider request rate limit (count %d / %d) - rate limit not being enforced", + requestNum-1, requestCount, providerRequestLimit) +} + +// TestProviderAndVKRateLimitBothEnforced verifies both provider and VK limits are enforced +func TestProviderAndVKRateLimitBothEnforced(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with both VK and provider request limits + vkName := "test-vk-both-enforced-" + generateRandomID() + vkRequestLimit := int64(5) + providerRequestLimit := int64(2) // More restrictive + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &vkRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &providerRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with VK limit (%d) and provider limit (%d requests)", vkRequestLimit, providerRequestLimit) + + // Make requests - provider limit (2) is more restrictive than VK limit (5) + // So we should hit provider limit first + successCount := 0 + for i := 0; i < 5; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request " + string(rune('0'+i)) + " to test both limits.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (count: %d)", i+1, successCount) + } else if resp.StatusCode >= 400 { + t.Logf("Request %d rejected with status %d", i+1, resp.StatusCode) + if successCount < int(providerRequestLimit) { + t.Fatalf("Request rejected before provider limit (%d): %v", providerRequestLimit, resp.Body) + } + // Expected - hit provider limit first + return + } + } + + if successCount > 0 { + if successCount >= 5 { + t.Fatalf("Made all %d requests without hitting rate limit (provider limit was %d) - rate limit not enforced", + successCount, providerRequestLimit) + } + t.Logf("Both VK and provider rate limits are configured and enforced ✓") + } else { + t.Skip("Could not test - all requests failed") + } +} + +// TestRateLimitInMemoryUsageTracking verifies usage counters are tracked in in-memory store +func TestRateLimitInMemoryUsageTracking(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with rate limit + vkName := "test-vk-usage-tracking-" + generateRandomID() + tokenLimit := int64(10000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK for usage tracking test") + + // Make a request + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test for usage tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not execute request for usage tracking test") + } + + // Get usage from response + var tokensUsed int + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if total, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int(total) + } + } + + if tokensUsed == 0 { + t.Skip("Could not extract token usage from response") + } + + t.Logf("Request used %d tokens", tokensUsed) + + // Wait for async update + time.Sleep(1 * time.Second) + + // Verify rate limit usage is tracked in in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap, ok := getDataResp.Body["virtual_keys"].(map[string]interface{}) + if !ok || virtualKeysMap == nil { + t.Fatalf("Virtual keys field missing or not a map in get response") + } + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData["rate_limit_id"].(string) + + if rateLimitID != "" { + t.Logf("Rate limit %s is configured and tracking usage ✓", rateLimitID) + } else { + t.Logf("Rate limit is configured ✓") + } +} diff --git a/plugins/governance/rate_limit_test.go b/plugins/governance/rate_limit_test.go new file mode 100644 index 000000000..8a5b4c815 --- /dev/null +++ b/plugins/governance/rate_limit_test.go @@ -0,0 +1,991 @@ +package governance + +import ( + "testing" + "time" +) + +// TestVirtualKeyTokenRateLimit tests that VK-level token rate limits are enforced +func TestVirtualKeyTokenRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a very restrictive token rate limit + vkName := "test-vk-token-limit-" + generateRandomID() + tokenLimit := int64(500) // Only 500 tokens per hour + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with token limit: %d tokens per %s", vkName, tokenLimit, tokenResetDuration) + + // Make requests until we hit the token limit + successCount := 0 + for i := 0; i < 10; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Short test request " + string(rune('0'+i)) + " for token limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to token rate limit", i+1) + return // Test passed - hit the token limit + } else { + t.Logf("Request %d failed with unexpected error: %v", i+1, resp.Body) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (tokens within limit)", i+1) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests before hitting token limit ✓", successCount) + } else { + t.Skip("Could not make requests to test token limit") + } +} + +// TestVirtualKeyRequestRateLimit tests that VK-level request rate limits are enforced +func TestVirtualKeyRequestRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a very restrictive request rate limit + vkName := "test-vk-request-limit-" + generateRandomID() + requestLimit := int64(3) // Only 3 requests per minute + requestResetDuration := "1m" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with request limit: %d requests per %s", vkName, requestLimit, requestResetDuration) + + // Make requests until we hit the request limit + successCount := 0 + for i := 0; i < 5; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request number " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to request rate limit", i+1) + return // Test passed + } else { + t.Logf("Request %d failed with different error", i+1) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (count: %d/%d)", i+1, successCount, requestLimit) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests before hitting request limit ✓", successCount) + } else { + t.Skip("Could not make requests to test request limit") + } +} + +// TestProviderConfigTokenRateLimit tests that provider-level token rate limits are enforced +func TestProviderConfigTokenRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a provider config that has a token rate limit + vkName := "test-vk-provider-token-limit-" + generateRandomID() + providerTokenLimit := int64(300) // Limited tokens per provider + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &providerTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with provider token limit: %d tokens per %s", vkName, providerTokenLimit, tokenResetDuration) + + // Make requests to openai until we hit provider token limit + successCount := 0 + for i := 0; i < 10; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Provider token limit test " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to provider token limit", i+1) + return // Test passed + } else { + t.Logf("Request %d failed with different error", i+1) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded", i+1) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests with provider token limit ✓", successCount) + } else { + t.Skip("Could not make requests to test provider token limit") + } +} + +// TestProviderConfigRequestRateLimit tests that provider-level request rate limits are enforced +func TestProviderConfigRequestRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a provider config that has a request rate limit + vkName := "test-vk-provider-request-limit-" + generateRandomID() + providerRequestLimit := int64(2) // Only 2 requests per minute for this provider + requestResetDuration := "1m" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &providerRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with provider request limit: %d requests per %s", vkName, providerRequestLimit, requestResetDuration) + + // Make requests to openai until we hit provider request limit + successCount := 0 + for i := 0; i < 5; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Provider request limit test " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to provider request limit", i+1) + return // Test passed + } else { + t.Logf("Request %d failed with different error", i+1) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (count: %d/%d)", i+1, successCount, providerRequestLimit) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests with provider request limit ✓", successCount) + } else { + t.Skip("Could not make requests to test provider request limit") + } +} + +// TestMultipleProvidersSeparateRateLimits tests that different providers have independent rate limits +func TestMultipleProvidersSeparateRateLimits(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with multiple providers, each with their own rate limits + vkName := "test-vk-multi-provider-limits-" + generateRandomID() + openaiLimit := int64(100) + anthropicLimit := int64(50) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &openaiLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + { + Provider: "anthropic", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &anthropicLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with separate rate limits per provider", vkName) + + // Verify both providers are allowed + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + if len(providerConfigs) != 2 { + t.Fatalf("Expected 2 provider configs, got %d", len(providerConfigs)) + } + + t.Logf("VK has %d provider configs with separate rate limits ✓", len(providerConfigs)) +} + +// TestProviderAndVKRateLimitTogether tests that both provider and VK rate limits are enforced together +func TestProviderAndVKRateLimitTogether(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both VK-level and provider-level rate limits + vkName := "test-vk-both-limits-" + generateRandomID() + vkTokenLimit := int64(1000) + vkTokenResetDuration := "1h" + providerTokenLimit := int64(300) + providerTokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &vkTokenLimit, + TokenResetDuration: &vkTokenResetDuration, + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &providerTokenLimit, + TokenResetDuration: &providerTokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with VK limit (%d tokens) and provider limit (%d tokens)", vkName, vkTokenLimit, providerTokenLimit) + + // Verify the VK has both limits configured + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + + // Check VK has rate limit + vkRateLimitID, _ := vkData["rate_limit_id"].(string) + if vkRateLimitID == "" { + t.Fatalf("VK rate limit ID not found") + } + + // Check provider config exists + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + if len(providerConfigs) == 0 { + t.Fatalf("No provider configs found") + } + + t.Logf("VK has both VK-level rate limit and provider-level rate limit configured ✓") +} + +// TestRateLimitInMemorySync tests that rate limit changes sync to in-memory store +func TestRateLimitInMemorySync(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a token rate limit + vkName := "test-vk-rate-limit-sync-" + generateRandomID() + initialTokenLimit := int64(1000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &initialTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with rate limit: %d tokens", vkName, initialTokenLimit) + + // Get initial rate limit from in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData["rate_limit_id"].(string) + + if rateLimitID == "" { + t.Fatalf("Rate limit ID not found in VK") + } + + // Update the rate limit + newTokenLimit := int64(5000) + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK rate limit: status %d", updateResp.StatusCode) + } + + t.Logf("Updated rate limit from %d to %d tokens", initialTokenLimit, newTokenLimit) + + // Verify rate limit is updated in in-memory store + time.Sleep(500 * time.Millisecond) + + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getDataResp2.StatusCode) + } + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + vkData2 := virtualKeysMap2[vkValue].(map[string]interface{}) + + // Verify VK still has rate limit configured + rateLimitID2, _ := vkData2["rate_limit_id"].(string) + if rateLimitID2 == "" { + t.Fatalf("Rate limit ID removed after update") + } + + // Verify it's the same rate limit (ID should match) + if rateLimitID2 != rateLimitID { + t.Fatalf("Rate limit ID changed after update: was %s, now %s", rateLimitID, rateLimitID2) + } + + // Verify rate limit content - check the actual values in the main RateLimits map + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2, ok := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + if !ok { + t.Fatalf("Rate limit not found in RateLimits map") + } + + // Check TokenMaxLimit was updated + tokenMaxLimit, ok := rateLimit2["token_max_limit"].(float64) + if !ok { + t.Fatalf("Token max limit not found in rate limit") + } + if int64(tokenMaxLimit) != newTokenLimit { + t.Fatalf("Token max limit not updated: expected %d but got %d", newTokenLimit, int64(tokenMaxLimit)) + } + t.Logf("Token max limit correctly updated to %d ✓", int64(tokenMaxLimit)) + + // Check TokenResetDuration persists + resetDuration, ok := rateLimit2["token_reset_duration"].(string) + if !ok { + t.Fatalf("Token reset duration not found in rate limit") + } + if resetDuration != tokenResetDuration { + t.Fatalf("Token reset duration changed: expected %s but got %s", tokenResetDuration, resetDuration) + } + t.Logf("Token reset duration persisted: %s ✓", resetDuration) + + // Check usage counters exist + if tokenCurrentUsage, ok := rateLimit2["token_current_usage"].(float64); ok { + t.Logf("Token current usage in memory: %d", int64(tokenCurrentUsage)) + } + + t.Logf("Rate limit in-memory sync verified ✓") + t.Logf("VK rate limit ID persisted: %s", rateLimitID2) +} + +// TestRateLimitTokenAndRequestTogether tests that both token and request limits work together +func TestRateLimitTokenAndRequestTogether(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both token and request limits + vkName := "test-vk-token-and-request-" + generateRandomID() + tokenLimit := int64(5000) + tokenResetDuration := "1h" + requestLimit := int64(100) + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with token limit (%d) and request limit (%d)", vkName, tokenLimit, requestLimit) + + // Make a few requests and verify both limits are being tracked + successCount := 0 + for i := 0; i < 3; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request for token and request limits " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded", i+1) + } else if resp.StatusCode >= 400 { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + break + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests with both token and request limits ✓", successCount) + } else { + t.Skip("Could not make requests to test combined limits") + } +} + +// TestRateLimitUsageTrackedInMemory tests that VK-level rate limit usage is tracked in in-memory store +func TestRateLimitUsageTrackedInMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both token and request rate limits + vkName := "test-vk-usage-tracking-" + generateRandomID() + tokenLimit := int64(100000) + tokenResetDuration := "1h" + requestLimit := int64(100) + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with rate limits for usage tracking", vkName) + + // Get initial state - rate limit usage should be 0 + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + rateLimitID1, _ := vkData1["rate_limit_id"].(string) + + initialTokenUsage := 0.0 + initialRequestUsage := 0.0 + + // Check initial rate limit usage (should be 0) from main RateLimits map + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + rateLimit1, ok := rateLimitsMap1[rateLimitID1].(map[string]interface{}) + if !ok { + t.Fatalf("Rate limit not found in RateLimits map") + } + + if tokenUsage, ok := rateLimit1["token_current_usage"].(float64); ok { + initialTokenUsage = tokenUsage + t.Logf("Initial token usage: %d", int64(initialTokenUsage)) + } + if requestUsage, ok := rateLimit1["request_current_usage"].(float64); ok { + initialRequestUsage = requestUsage + t.Logf("Initial request usage: %d", int64(initialRequestUsage)) + } + + // Make a request to use some tokens and increment request count + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request for usage tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to test usage tracking") + } + + // Wait for async update to in-memory store + time.Sleep(500 * time.Millisecond) + + // Get updated state - rate limit usage should have increased + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + vkData2 := virtualKeysMap2[vkValue].(map[string]interface{}) + rateLimitID2, _ := vkData2["rate_limit_id"].(string) + + // Get rate limit from main RateLimits map + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2, ok := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + if !ok { + t.Fatalf("Rate limit not found in RateLimits map after request") + } + + // Check that token usage increased + tokenUsage2, ok := rateLimit2["token_current_usage"].(float64) + if !ok { + t.Fatalf("Token current usage not found in rate limit") + } + + if tokenUsage2 <= initialTokenUsage { + t.Logf("Warning: Token usage did not increase (before: %d, after: %d)", int64(initialTokenUsage), int64(tokenUsage2)) + } else { + t.Logf("Token usage increased from %d to %d ✓", int64(initialTokenUsage), int64(tokenUsage2)) + } + + // Check that request usage increased + requestUsage2, ok := rateLimit2["request_current_usage"].(float64) + if !ok { + t.Fatalf("Request current usage not found in rate limit") + } + + if requestUsage2 <= initialRequestUsage { + t.Logf("Warning: Request usage did not increase (before: %d, after: %d)", int64(initialRequestUsage), int64(requestUsage2)) + } else { + t.Logf("Request usage increased from %d to %d ✓", int64(initialRequestUsage), int64(requestUsage2)) + } + + // Verify rate limit still has the configured max limits + tokenMaxLimit, ok := rateLimit2["token_max_limit"].(float64) + if ok && int64(tokenMaxLimit) != tokenLimit { + t.Fatalf("Token max limit changed: expected %d but got %d", tokenLimit, int64(tokenMaxLimit)) + } + + requestMaxLimit, ok := rateLimit2["request_max_limit"].(float64) + if ok && int64(requestMaxLimit) != requestLimit { + t.Fatalf("Request max limit changed: expected %d but got %d", requestLimit, int64(requestMaxLimit)) + } + + t.Logf("VK-level rate limit usage properly tracked in in-memory store ✓") + t.Logf("Token usage: %d/%d, Request usage: %d/%d", + int64(tokenUsage2), tokenLimit, int64(requestUsage2), requestLimit) +} + +// TestProviderLevelRateLimitUsageTracking tests that provider-level rate limits are separately tracked +func TestProviderLevelRateLimitUsageTracking(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with multiple providers, each with their own rate limits + vkName := "test-vk-provider-usage-" + generateRandomID() + openaiTokenLimit := int64(50000) + anthropicTokenLimit := int64(30000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &openaiTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + { + Provider: "anthropic", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &anthropicTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with per-provider rate limits", vkName) + + // Get initial state - provider rate limit usage should be 0 + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + + providerConfigs1, ok := vkData1["provider_configs"].([]interface{}) + if !ok { + t.Fatalf("Provider configs not found in VK data") + } + + if len(providerConfigs1) != 2 { + t.Fatalf("Expected 2 provider configs, got %d", len(providerConfigs1)) + } + + t.Logf("VK has %d provider configs with separate rate limits", len(providerConfigs1)) + + // Make a request with openai model to use openai provider's rate limit + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request for provider rate limit tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to test provider rate limit tracking") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Get updated state - openai provider rate limit usage should have increased + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + vkData2 := virtualKeysMap2[vkValue].(map[string]interface{}) + + providerConfigs2, ok := vkData2["provider_configs"].([]interface{}) + if !ok { + t.Fatalf("Provider configs not found in VK data after request") + } + + // Check each provider config for rate limit updates + var openaiUsage, anthropicUsage float64 + var openaiMaxLimit, anthropicMaxLimit float64 + + // Get rate limits from main RateLimits map + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + + for i, providerConfig := range providerConfigs2 { + config, ok := providerConfig.(map[string]interface{}) + if !ok { + continue + } + + provider, ok := config["provider"].(string) + if !ok { + continue + } + + rateLimitID, ok := config["rate_limit_id"].(string) + if !ok { + t.Logf("Provider %s: No rate limit ID found", provider) + continue + } + + rateLimit, ok := rateLimitsMap2[rateLimitID].(map[string]interface{}) + if !ok { + t.Logf("Provider %s: No rate limit found in RateLimits map", provider) + continue + } + + tokenUsage, _ := rateLimit["token_current_usage"].(float64) + tokenMaxLimit, _ := rateLimit["token_max_limit"].(float64) + + if provider == "openai" { + openaiUsage = tokenUsage + openaiMaxLimit = tokenMaxLimit + t.Logf("Provider %d (openai): Token usage: %d/%d", i, int64(tokenUsage), int64(tokenMaxLimit)) + } else if provider == "anthropic" { + anthropicUsage = tokenUsage + anthropicMaxLimit = tokenMaxLimit + t.Logf("Provider %d (anthropic): Token usage: %d/%d", i, int64(tokenUsage), int64(tokenMaxLimit)) + } + } + + // Verify provider limits are independent + if openaiMaxLimit != float64(openaiTokenLimit) { + t.Logf("Warning: OpenAI max limit changed: expected %d but got %d", openaiTokenLimit, int64(openaiMaxLimit)) + } + + if anthropicMaxLimit != float64(anthropicTokenLimit) { + t.Logf("Warning: Anthropic max limit changed: expected %d but got %d", anthropicTokenLimit, int64(anthropicMaxLimit)) + } + + t.Logf("Provider-level rate limits properly tracked separately in in-memory store ✓") + t.Logf("OpenAI usage: %d, Anthropic usage: %d (separate limits)", int64(openaiUsage), int64(anthropicUsage)) +} diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 4518a1ba4..e37f92a97 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "slices" - "strings" "time" "github.com/maximhq/bifrost/core/schemas" @@ -63,12 +62,12 @@ type UsageInfo struct { // BudgetResolver provides decision logic for the new hierarchical governance system type BudgetResolver struct { - store *GovernanceStore + store GovernanceStore logger schemas.Logger } // NewBudgetResolver creates a new budget-based governance resolver -func NewBudgetResolver(store *GovernanceStore, logger schemas.Logger) *BudgetResolver { +func NewBudgetResolver(store GovernanceStore, logger schemas.Logger) *BudgetResolver { return &BudgetResolver{ store: store, logger: logger, @@ -127,13 +126,13 @@ func (r *BudgetResolver) EvaluateRequest(ctx *schemas.BifrostContext, evaluation } } - // 4. Check rate limits (Provider level first, then VK level) - if rateLimitResult := r.checkRateLimits(vk, string(evaluationRequest.Provider)); rateLimitResult != nil { + // 4. Check rate limits hierarchy (Provider level first, then VK level) + if rateLimitResult := r.checkRateLimitHierarchy(ctx, vk, string(evaluationRequest.Provider), evaluationRequest.Model, evaluationRequest.RequestID); rateLimitResult != nil { return rateLimitResult } // 5. Check budget hierarchy (VK → Team → Customer) - if budgetResult := r.checkBudgetHierarchy(ctx, vk, evaluationRequest.Provider); budgetResult != nil { + if budgetResult := r.checkBudgetHierarchy(ctx, vk, evaluationRequest); budgetResult != nil { return budgetResult } @@ -192,77 +191,25 @@ func (r *BudgetResolver) isProviderAllowed(vk *configstoreTables.TableVirtualKey return false } -// checkRateLimits checks provider-level rate limits first, then VK rate limits using flexible approach -func (r *BudgetResolver) checkRateLimits(vk *configstoreTables.TableVirtualKey, provider string) *EvaluationResult { - // First check provider-level rate limits - if providerRateLimitResult := r.checkProviderRateLimits(vk, provider); providerRateLimitResult != nil { - return providerRateLimitResult - } - - // Then check VK-level rate limits - if vk.RateLimit == nil { - return nil // No VK rate limits defined - } - - return r.checkSingleRateLimit(vk.RateLimit, "virtual key", vk) -} - -// checkProviderRateLimits checks rate limits for a specific provider config -func (r *BudgetResolver) checkProviderRateLimits(vk *configstoreTables.TableVirtualKey, provider string) *EvaluationResult { - if vk.ProviderConfigs == nil { - return nil // No provider configs defined - } - - // Find the specific provider config - for _, pc := range vk.ProviderConfigs { - if pc.Provider == provider && pc.RateLimit != nil { - return r.checkSingleRateLimit(pc.RateLimit, fmt.Sprintf("provider '%s'", provider), vk) - } - } - - return nil // No rate limits for this provider -} - -// checkSingleRateLimit checks a single rate limit and returns evaluation result if violated -func (r *BudgetResolver) checkSingleRateLimit(rateLimit *configstoreTables.TableRateLimit, rateLimitName string, vk *configstoreTables.TableVirtualKey) *EvaluationResult { - var violations []string - - // Token limits - if rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage >= *rateLimit.TokenMaxLimit { - duration := "unknown" - if rateLimit.TokenResetDuration != nil { - duration = *rateLimit.TokenResetDuration - } - violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", - rateLimit.TokenCurrentUsage, *rateLimit.TokenMaxLimit, duration)) - } - - // Request limits - if rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage >= *rateLimit.RequestMaxLimit { - duration := "unknown" - if rateLimit.RequestResetDuration != nil { - duration = *rateLimit.RequestResetDuration - } - violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", - rateLimit.RequestCurrentUsage, *rateLimit.RequestMaxLimit, duration)) - } - - if len(violations) > 0 { - // Determine specific violation type - decision := DecisionRateLimited - if len(violations) == 1 { - if strings.Contains(violations[0], "token") { - decision = DecisionTokenLimited - } else if strings.Contains(violations[0], "request") { - decision = DecisionRequestLimited +// checkRateLimitHierarchy checks provider-level rate limits first, then VK rate limits using flexible approach +func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider string, model string, requestID string) *EvaluationResult { + if decision, err := r.store.CheckRateLimit(ctx, vk, schemas.ModelProvider(provider), model, requestID, nil, nil); err != nil { + // Check provider-level first (matching check order), then VK-level + var rateLimitInfo *configstoreTables.TableRateLimit + for _, pc := range vk.ProviderConfigs { + if pc.Provider == provider && pc.RateLimit != nil { + rateLimitInfo = pc.RateLimit + break } } - + if rateLimitInfo == nil && vk.RateLimit != nil { + rateLimitInfo = vk.RateLimit + } return &EvaluationResult{ Decision: decision, - Reason: fmt.Sprintf("%s rate limits exceeded: %v", rateLimitName, violations), + Reason: fmt.Sprintf("Rate limit check failed: %s", err.Error()), VirtualKey: vk, - RateLimitInfo: rateLimit, + RateLimitInfo: rateLimitInfo, } } @@ -270,14 +217,14 @@ func (r *BudgetResolver) checkSingleRateLimit(rateLimit *configstoreTables.Table } // checkBudgetHierarchy checks the budget hierarchy atomically (VK → Team → Customer) -func (r *BudgetResolver) checkBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) *EvaluationResult { +func (r *BudgetResolver) checkBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest) *EvaluationResult { // Use atomic budget checking to prevent race conditions - if err := r.store.CheckBudget(ctx, vk, provider); err != nil { - r.logger.Debug(fmt.Sprintf("Atomic budget check failed for VK %s: %s", vk.ID, err.Error())) + if err := r.store.CheckBudget(ctx, vk, request, nil); err != nil { + r.logger.Debug(fmt.Sprintf("Atomic budget exceeded for VK %s: %s", vk.ID, err.Error())) return &EvaluationResult{ Decision: DecisionBudgetExceeded, - Reason: fmt.Sprintf("Budget check failed: %s", err.Error()), + Reason: fmt.Sprintf("Budget exceeded: %s", err.Error()), VirtualKey: vk, } } diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go new file mode 100644 index 000000000..3b960e759 --- /dev/null +++ b/plugins/governance/resolver_test.go @@ -0,0 +1,551 @@ +package governance + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBudgetResolver_EvaluateRequest_AllowedRequest tests happy path +func TestBudgetResolver_EvaluateRequest_AllowedRequest(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + RequestID: "req-123", + }) + + assertDecision(t, DecisionAllow, result) + assertVirtualKeyFound(t, result) +} + +// TestBudgetResolver_EvaluateRequest_VirtualKeyNotFound tests missing VK +func TestBudgetResolver_EvaluateRequest_VirtualKeyNotFound(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-nonexistent", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionVirtualKeyNotFound, result) +} + +// TestBudgetResolver_EvaluateRequest_VirtualKeyBlocked tests inactive VK +func TestBudgetResolver_EvaluateRequest_VirtualKeyBlocked(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", false) // Inactive + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionVirtualKeyBlocked, result) +} + +// TestBudgetResolver_EvaluateRequest_ProviderBlocked tests provider filtering +func TestBudgetResolver_EvaluateRequest_ProviderBlocked(t *testing.T) { + logger := NewMockLogger() + + // VK with only Anthropic allowed + providerConfigs := []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("anthropic", []string{"claude-3-sonnet"}), + } + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", providerConfigs) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + // Try to use OpenAI (not allowed) + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionProviderBlocked, result) + assertVirtualKeyFound(t, result) +} + +// TestBudgetResolver_EvaluateRequest_ModelBlocked tests model filtering +func TestBudgetResolver_EvaluateRequest_ModelBlocked(t *testing.T) { + logger := NewMockLogger() + + // VK with specific models allowed + providerConfigs := []configstoreTables.TableVirtualKeyProviderConfig{ + { + Provider: "openai", + AllowedModels: []string{"gpt-4", "gpt-4-turbo"}, // Only these models + Weight: 1.0, + RateLimit: nil, + Budget: nil, + Keys: []configstoreTables.TableKey{}, + }, + } + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", providerConfigs) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + // Try to use gpt-4o-mini (not in allowed list) + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + }) + + assertDecision(t, DecisionModelBlocked, result) +} + +// TestBudgetResolver_EvaluateRequest_RateLimitExceeded_TokenLimit tests token limit +func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_TokenLimit(t *testing.T) { + logger := NewMockLogger() + + // VK with rate limit already at max + rateLimit := buildRateLimitWithUsage("rl1", 10000, 10000, 1000, 0) // Tokens at max + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionTokenLimited, result) + assertRateLimitInfo(t, result) +} + +// TestBudgetResolver_EvaluateRequest_RateLimitExceeded_RequestLimit tests request limit +func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_RequestLimit(t *testing.T) { + logger := NewMockLogger() + + // VK with request limit already at max + rateLimit := buildRateLimitWithUsage("rl1", 10000, 0, 100, 100) // Requests at max + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionRequestLimited, result) +} + +// TestBudgetResolver_EvaluateRequest_RateLimitExpired tests rate limit reset +func TestBudgetResolver_EvaluateRequest_RateLimitExpired(t *testing.T) { + logger := NewMockLogger() + + // VK with rate limit that's expired (should be treated as reset) + duration := "1m" + rateLimit := &configstoreTables.TableRateLimit{ + ID: "rl1", + TokenMaxLimit: ptrInt64(10000), + TokenCurrentUsage: 10000, // At limit + TokenResetDuration: &duration, + TokenLastReset: time.Now().Add(-2 * time.Minute), // Expired + RequestMaxLimit: ptrInt64(1000), + RequestCurrentUsage: 0, + RequestResetDuration: &duration, + RequestLastReset: time.Now(), + } + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + // Reset expired rate limits (simulating ticker behavior) + expiredRateLimits := store.ResetExpiredRateLimitsInMemory(context.Background()) + err = store.ResetExpiredRateLimits(context.Background(), expiredRateLimits) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + // Should allow because rate limit was expired and has been reset + assertDecision(t, DecisionAllow, result) +} + +// TestBudgetResolver_EvaluateRequest_BudgetExceeded tests budget violation +func TestBudgetResolver_EvaluateRequest_BudgetExceeded(t *testing.T) { + logger := NewMockLogger() + + budget := buildBudgetWithUsage("budget1", 100.0, 100.0, "1d") // At limit + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionBudgetExceeded, result) +} + +// TestBudgetResolver_EvaluateRequest_BudgetExpired tests expired budget (should be treated as reset) +func TestBudgetResolver_EvaluateRequest_BudgetExpired(t *testing.T) { + logger := NewMockLogger() + + budget := &configstoreTables.TableBudget{ + ID: "budget1", + MaxLimit: 100.0, + CurrentUsage: 100.0, // At limit + ResetDuration: "1d", + LastReset: time.Now().Add(-48 * time.Hour), // Expired + } + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + // Should allow because budget is expired (will be reset) + assertDecision(t, DecisionAllow, result) +} + +// TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy tests hierarchy checking +func TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy(t *testing.T) { + logger := NewMockLogger() + + vkBudget := buildBudgetWithUsage("vk-budget", 100.0, 50.0, "1d") + teamBudget := buildBudgetWithUsage("team-budget", 500.0, 200.0, "1d") + customerBudget := buildBudgetWithUsage("customer-budget", 1000.0, 400.0, "1d") + + team := buildTeam("team1", "Team 1", teamBudget) + customer := buildCustomer("customer1", "Customer 1", customerBudget) + team.CustomerID = &customer.ID + team.Customer = customer + + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", vkBudget) + vk.TeamID = &team.ID + vk.Team = team + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*vkBudget, *teamBudget, *customerBudget}, + Teams: []configstoreTables.TableTeam{*team}, + Customers: []configstoreTables.TableCustomer{*customer}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + // Test: All under limit should pass + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + assertDecision(t, DecisionAllow, result) + + // Test: VK budget exceeds should fail + // Get the governance data to update the budget directly + governanceData := store.GetGovernanceData() + vkBudgetToUpdate := governanceData.Budgets["vk-budget"] + if vkBudgetToUpdate != nil { + vkBudgetToUpdate.CurrentUsage = 100.0 + store.budgets.Store("vk-budget", vkBudgetToUpdate) + } + result = resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + assertDecision(t, DecisionBudgetExceeded, result) +} + +// TestBudgetResolver_EvaluateRequest_ProviderLevelRateLimit tests provider-specific rate limits +func TestBudgetResolver_EvaluateRequest_ProviderLevelRateLimit(t *testing.T) { + logger := NewMockLogger() + + // Provider with rate limit at max + providerRL := buildRateLimitWithUsage("provider-rl", 5000, 5000, 500, 0) + providerConfig := buildProviderConfigWithRateLimit("openai", []string{"gpt-4"}, providerRL) + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", []configstoreTables.TableVirtualKeyProviderConfig{providerConfig}) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*providerRL}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionTokenLimited, result) + assertRateLimitInfo(t, result) +} + +// TestBudgetResolver_CheckRateLimits_BothExceeded tests token and request limits simultaneously +func TestBudgetResolver_CheckRateLimits_BothExceeded(t *testing.T) { + logger := NewMockLogger() + + // Rate limit with both token and request at max + rateLimit := buildRateLimitWithUsage("rl1", 1000, 1000, 100, 100) + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionRateLimited, result) + assert.Contains(t, result.Reason, "rate limit") +} + +// TestBudgetResolver_IsProviderAllowed tests provider filtering logic +func TestBudgetResolver_IsProviderAllowed(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + + tests := []struct { + name string + vk *configstoreTables.TableVirtualKey + provider schemas.ModelProvider + shouldBeAllowed bool + }{ + { + name: "No provider configs (all allowed)", + vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), + provider: schemas.OpenAI, + shouldBeAllowed: true, + }, + { + name: "Provider in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"gpt-4"}), + }), + provider: schemas.OpenAI, + shouldBeAllowed: true, + }, + { + name: "Provider not in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("anthropic", []string{"claude-3-sonnet"}), + }), + provider: schemas.OpenAI, + shouldBeAllowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allowed := resolver.isProviderAllowed(tt.vk, tt.provider) + assert.Equal(t, tt.shouldBeAllowed, allowed) + }) + } +} + +// TestBudgetResolver_IsModelAllowed tests model filtering logic +func TestBudgetResolver_IsModelAllowed(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + + tests := []struct { + name string + vk *configstoreTables.TableVirtualKey + provider schemas.ModelProvider + model string + shouldBeAllowed bool + }{ + { + name: "No provider configs (all models allowed)", + vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: true, + }, + { + name: "Empty allowed models (all models allowed)", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{}), // Empty = all allowed + }), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: true, + }, + { + name: "Model in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"gpt-4", "gpt-4-turbo"}), + }), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: true, + }, + { + name: "Model not in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"gpt-4", "gpt-4-turbo"}), + }), + provider: schemas.OpenAI, + model: "gpt-4o-mini", + shouldBeAllowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allowed := resolver.isModelAllowed(tt.vk, tt.provider, tt.model) + assert.Equal(t, tt.shouldBeAllowed, allowed) + }) + } +} + +// TestBudgetResolver_ContextPopulation tests context values are set correctly +func TestBudgetResolver_ContextPopulation(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + customer := buildCustomer("cust1", "Customer 1", nil) + team := buildTeam("team1", "Team 1", nil) + team.CustomerID = &customer.ID + team.Customer = customer + vk.TeamID = &team.ID + vk.Team = team + vk.CustomerID = &customer.ID + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Teams: []configstoreTables.TableTeam{*team}, + Customers: []configstoreTables.TableCustomer{*customer}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assert.Equal(t, DecisionAllow, result.Decision) + + // Check context was populated + vkID, _ := ctx.Value(schemas.BifrostContextKey("bf-governance-virtual-key-id")).(string) + teamID, _ := ctx.Value(schemas.BifrostContextKey("bf-governance-team-id")).(string) + customerID, _ := ctx.Value(schemas.BifrostContextKey("bf-governance-customer-id")).(string) + + assert.Equal(t, "vk1", vkID) + assert.Equal(t, "team1", teamID) + assert.Equal(t, "cust1", customerID) +} diff --git a/plugins/governance/store.go b/plugins/governance/store.go index f82240fb0..c3ef36fe3 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -4,6 +4,7 @@ package governance import ( "context" "fmt" + "strings" "sync" "time" @@ -11,16 +12,16 @@ import ( "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "gorm.io/gorm" - "gorm.io/gorm/clause" ) -// GovernanceStore provides in-memory cache for governance data with fast, non-blocking access -type GovernanceStore struct { +// LocalGovernanceStore provides in-memory cache for governance data with fast, non-blocking access +type LocalGovernanceStore struct { // Core data maps using sync.Map for lock-free reads virtualKeys sync.Map // string -> *VirtualKey (VK value -> VirtualKey with preloaded relationships) teams sync.Map // string -> *Team (Team ID -> Team) customers sync.Map // string -> *Customer (Customer ID -> Customer) budgets sync.Map // string -> *Budget (Budget ID -> Budget) + rateLimits sync.Map // string -> *RateLimit (RateLimit ID -> RateLimit) // Config store for refresh operations configStore configstore.ConfigStore @@ -29,9 +30,50 @@ type GovernanceStore struct { logger schemas.Logger } -// NewGovernanceStore creates a new in-memory governance store -func NewGovernanceStore(ctx context.Context, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig) (*GovernanceStore, error) { - store := &GovernanceStore{ +type GovernanceData struct { + VirtualKeys map[string]*configstoreTables.TableVirtualKey `json:"virtual_keys"` + Teams map[string]*configstoreTables.TableTeam `json:"teams"` + Customers map[string]*configstoreTables.TableCustomer `json:"customers"` + Budgets map[string]*configstoreTables.TableBudget `json:"budgets"` + RateLimits map[string]*configstoreTables.TableRateLimit `json:"rate_limits"` +} + +// GovernanceStore defines the interface for governance data access and policy evaluation. +// +// Error semantics contract: +// - CheckRateLimit and CheckBudget return a non-nil error to indicate a governance/policy +// violation (not an infrastructure/operational failure). +// - Callers must treat any non-nil error from these methods as an explicit denial/violation +// decision rather than a retryable infrastructure error. +// - This contract ensures consistent behavior across implementations (e.g., in-memory, +// DB-backed) and prevents retry loops on policy violations. +type GovernanceStore interface { + GetGovernanceData() *GovernanceData + GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) + CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error + CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) + UpdateBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error + UpdateRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error + ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit + ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget + ResetExpiredRateLimits(ctx context.Context, resetRateLimits []*configstoreTables.TableRateLimit) error + ResetExpiredBudgets(ctx context.Context, resetBudgets []*configstoreTables.TableBudget) error + DumpRateLimits(ctx context.Context, tokenBaselines map[string]int64, requestBaselines map[string]int64) error + DumpBudgets(ctx context.Context, baselines map[string]float64) error + CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) + UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) + DeleteVirtualKeyInMemory(vkID string) + CreateTeamInMemory(team *configstoreTables.TableTeam) + UpdateTeamInMemory(team *configstoreTables.TableTeam, budgetBaselines map[string]float64) + DeleteTeamInMemory(teamID string) + CreateCustomerInMemory(customer *configstoreTables.TableCustomer) + UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) + DeleteCustomerInMemory(customerID string) +} + +// NewLocalGovernanceStore creates a new in-memory governance store +func NewLocalGovernanceStore(ctx context.Context, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig) (*LocalGovernanceStore, error) { + store := &LocalGovernanceStore{ configStore: configStore, logger: logger, } @@ -51,8 +93,63 @@ func NewGovernanceStore(ctx context.Context, logger schemas.Logger, configStore return store, nil } +func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { + virtualKeys := make(map[string]*configstoreTables.TableVirtualKey) + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + virtualKeys[key.(string)] = vk + return true // continue iteration + }) + teams := make(map[string]*configstoreTables.TableTeam) + gs.teams.Range(func(key, value interface{}) bool { + team, ok := value.(*configstoreTables.TableTeam) + if !ok || team == nil { + return true // continue + } + teams[key.(string)] = team + return true // continue iteration + }) + customers := make(map[string]*configstoreTables.TableCustomer) + gs.customers.Range(func(key, value interface{}) bool { + customer, ok := value.(*configstoreTables.TableCustomer) + if !ok || customer == nil { + return true // continue + } + customers[key.(string)] = customer + return true // continue iteration + }) + budgets := make(map[string]*configstoreTables.TableBudget) + gs.budgets.Range(func(key, value interface{}) bool { + budget, ok := value.(*configstoreTables.TableBudget) + if !ok || budget == nil { + return true // continue + } + budgets[key.(string)] = budget + return true // continue iteration + }) + rateLimits := make(map[string]*configstoreTables.TableRateLimit) + gs.rateLimits.Range(func(key, value interface{}) bool { + rateLimit, ok := value.(*configstoreTables.TableRateLimit) + if !ok || rateLimit == nil { + return true // continue + } + rateLimits[key.(string)] = rateLimit + return true // continue iteration + }) + return &GovernanceData{ + VirtualKeys: virtualKeys, + Teams: teams, + Customers: customers, + Budgets: budgets, + RateLimits: rateLimits, + } +} + // GetVirtualKey retrieves a virtual key by its value (lock-free) with all relationships preloaded -func (gs *GovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) { +func (gs *LocalGovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) { value, exists := gs.virtualKeys.Load(vkValue) if !exists || value == nil { return nil, false @@ -65,308 +162,582 @@ func (gs *GovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.Tab return vk, true } -// GetAllBudgets returns all budgets (for background reset operations) -func (gs *GovernanceStore) GetAllBudgets() map[string]*configstoreTables.TableBudget { - result := make(map[string]*configstoreTables.TableBudget) - gs.budgets.Range(func(key, value interface{}) bool { - // Type-safe conversion - keyStr, keyOk := key.(string) - budget, budgetOk := value.(*configstoreTables.TableBudget) - - if keyOk && budgetOk && budget != nil { - result[keyStr] = budget - } - return true // continue iteration - }) - return result -} - // CheckBudget performs budget checking using in-memory store data (lock-free for high performance) -func (gs *GovernanceStore) CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) error { +func (gs *LocalGovernanceStore) CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error { if vk == nil { return fmt.Errorf("virtual key cannot be nil") } + // This is to prevent nil pointer dereference + if baselines == nil { + baselines = map[string]float64{} + } + // Use helper to collect budgets and their names (lock-free) - budgetsToCheck, budgetNames := gs.collectBudgetsFromHierarchy(ctx, vk, provider) + budgetsToCheck, budgetNames := gs.collectBudgetsFromHierarchy(vk, request.Provider) + + gs.logger.Debug("LocalStore CheckBudget: Received %d baselines from remote nodes", len(baselines)) + for budgetID, baseline := range baselines { + gs.logger.Debug(" - Baseline for budget %s: %.4f", budgetID, baseline) + } // Check each budget in hierarchy order using in-memory data for i, budget := range budgetsToCheck { // Check if budget needs reset (in-memory check) if budget.ResetDuration != "" { if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { - if time.Since(budget.LastReset).Round(time.Millisecond) >= duration { + if time.Since(budget.LastReset) >= duration { // Budget expired but hasn't been reset yet - treat as reset // Note: actual reset will happen in post-hook via AtomicBudgetUpdate + gs.logger.Debug("LocalStore CheckBudget: Budget %s (%s) expired, skipping check", budget.ID, budgetNames[i]) continue // Skip budget check for expired budgets } } } - // Check if current usage exceeds budget limit - if budget.CurrentUsage > budget.MaxLimit { - return fmt.Errorf("%s budget exceeded: %.4f > %.4f dollars", - budgetNames[i], budget.CurrentUsage, budget.MaxLimit) + baseline, exists := baselines[budget.ID] + if !exists { + baseline = 0 + } + + gs.logger.Debug("LocalStore CheckBudget: Checking %s budget %s: local=%.4f, remote=%.4f, total=%.4f, limit=%.4f", + budgetNames[i], budget.ID, budget.CurrentUsage, baseline, budget.CurrentUsage+baseline, budget.MaxLimit) + + // Check if current usage (local + remote baseline) exceeds budget limit + if budget.CurrentUsage+baseline >= budget.MaxLimit { + gs.logger.Debug("LocalStore CheckBudget: Budget %s EXCEEDED", budget.ID) + return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", + budgetNames[i], budget.CurrentUsage+baseline, budget.MaxLimit) } } + gs.logger.Debug("LocalStore CheckBudget: All budgets passed") + return nil } -// UpdateBudget performs atomic budget updates across the hierarchy (both in memory and in database) -func (gs *GovernanceStore) UpdateBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error { - if vk == nil { - return fmt.Errorf("virtual key cannot be nil") - } +// CheckRateLimit checks a single rate limit and returns evaluation result if violated (true if violated, false if not) +func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { + var violations []string - // Collect budget IDs using fast in-memory lookup instead of DB queries - budgetIDs := gs.collectBudgetIDsFromMemory(ctx, vk, provider) + // Collect rate limits and their names from the hierarchy + rateLimits, rateLimitNames := gs.collectRateLimitsFromHierarchy(vk, provider) - if gs.configStore == nil { - for _, budgetID := range budgetIDs { - // Update in-memory cache for next read (lock-free) - if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { - if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { - clone := *cachedBudget - clone.CurrentUsage += cost - gs.budgets.Store(budgetID, &clone) + // This is to prevent nil pointer dereference + if tokensBaselines == nil { + tokensBaselines = map[string]int64{} + } + if requestsBaselines == nil { + requestsBaselines = map[string]int64{} + } + + for i, rateLimit := range rateLimits { + // Determine token and request expiration independently + tokenExpired := false + requestExpired := false + + // Check if token reset duration is expired + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if time.Since(rateLimit.TokenLastReset) >= duration { + // Token rate limit expired but hasn't been reset yet - skip token checks + // Note: actual reset will happen in post-hook via AtomicRateLimitUpdate + tokenExpired = true } } } - return nil - } - - return gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { - // budgetIDs already collected from in-memory data - no need to duplicate - - // Update each budget atomically - for _, budgetID := range budgetIDs { - var budget configstoreTables.TableBudget - if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&budget, "id = ?", budgetID).Error; err != nil { - return fmt.Errorf("failed to lock budget %s: %w", budgetID, err) + // Check if request reset duration is expired + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if time.Since(rateLimit.RequestLastReset) >= duration { + // Request rate limit expired but hasn't been reset yet - skip request checks + // Note: actual reset will happen in post-hook via AtomicRateLimitUpdate + requestExpired = true + } } + } - // Check if budget needs reset - if err := gs.resetBudgetIfNeeded(ctx, tx, &budget); err != nil { - return fmt.Errorf("failed to reset budget: %w", err) + tokensBaseline, exists := tokensBaselines[rateLimit.ID] + if !exists { + tokensBaseline = 0 + } + requestsBaseline, exists := requestsBaselines[rateLimit.ID] + if !exists { + requestsBaseline = 0 + } + + // Token limits - check if total usage (local + remote baseline) exceeds limit + // Only check if token limit is not expired + if !tokenExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { + duration := "unknown" + if rateLimit.TokenResetDuration != nil { + duration = *rateLimit.TokenResetDuration } + violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", + rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) + } - // Update usage - budget.CurrentUsage += cost - if err := gs.configStore.UpdateBudget(ctx, &budget, tx); err != nil { - return fmt.Errorf("failed to save budget %s: %w", budgetID, err) + // Request limits - check if total usage (local + remote baseline) exceeds limit + // Only check if request limit is not expired + if !requestExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { + duration := "unknown" + if rateLimit.RequestResetDuration != nil { + duration = *rateLimit.RequestResetDuration } + violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", + rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) + } - // Update in-memory cache for next read (lock-free) - if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { - if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { - clone := *cachedBudget - clone.CurrentUsage += cost - clone.LastReset = budget.LastReset - gs.budgets.Store(budgetID, &clone) + if len(violations) > 0 { + // Determine specific violation type + decision := DecisionRateLimited // Default to general rate limited decision + if len(violations) == 1 { + if strings.Contains(violations[0], "token") { + decision = DecisionTokenLimited // More specific violation type + } else if strings.Contains(violations[0], "request") { + decision = DecisionRequestLimited // More specific violation type } } + msg := strings.Join(violations, "; ") + return decision, fmt.Errorf("rate limit violated for %s: %s", rateLimitNames[i], msg) } + } - return nil - }) + return DecisionAllow, nil // No rate limit violations } -// UpdateRateLimitUsage updates rate limit counters for both provider-level and VK-level rate limits (lock-free) -func (gs *GovernanceStore) UpdateRateLimitUsage(ctx context.Context, vkValue string, provider string, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { - if vkValue == "" { - return fmt.Errorf("virtual key value cannot be empty") +// UpdateBudgetUsageInMemory performs atomic budget updates across the hierarchy (both in memory and in database) +func (gs *LocalGovernanceStore) UpdateBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error { + if vk == nil { + return fmt.Errorf("virtual key cannot be nil") } - vkValue_, exists := gs.virtualKeys.Load(vkValue) - if !exists || vkValue_ == nil { - return fmt.Errorf("virtual key not found: %s", vkValue) + // Collect budget IDs using fast in-memory lookup instead of DB queries + budgetIDs := gs.collectBudgetIDsFromMemory(ctx, vk, provider) + now := time.Now() + for _, budgetID := range budgetIDs { + // Update in-memory cache for next read (lock-free) + if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { + if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { + // Clone FIRST to avoid race conditions + clone := *cachedBudget + oldUsage := clone.CurrentUsage + + // Check if budget needs reset (in-memory check) - operate on clone + if clone.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { + if now.Sub(clone.LastReset) >= duration { + clone.CurrentUsage = 0 + clone.LastReset = now + gs.logger.Debug("UpdateBudgetUsage: Budget %s was reset (expired, duration: %v)", budgetID, duration) + } + } + } + + // Update the clone + clone.CurrentUsage += cost + gs.budgets.Store(budgetID, &clone) + gs.logger.Debug("UpdateBudgetUsage: Updated budget %s: %.4f -> %.4f (added %.4f)", + budgetID, oldUsage, clone.CurrentUsage, cost) + } + } else { + gs.logger.Warn("UpdateBudgetUsage: Budget %s not found in local store", budgetID) + } } + return nil +} - vk, ok := vkValue_.(*configstoreTables.TableVirtualKey) - if !ok || vk == nil { - return fmt.Errorf("invalid virtual key type for: %s", vkValue) +// UpdateRateLimitUsageInMemory updates rate limit counters for both provider-level and VK-level rate limits (lock-free) +func (gs *LocalGovernanceStore) UpdateRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { + if vk == nil { + return fmt.Errorf("virtual key cannot be nil") } - var rateLimitsToUpdate []*configstoreTables.TableRateLimit + // Collect rate limit IDs using fast in-memory lookup instead of DB queries + rateLimitIDs := gs.collectRateLimitIDsFromMemory(vk, provider) + now := time.Now() - // First, update provider-level rate limits if they exist - if provider != "" && vk.ProviderConfigs != nil { - for _, pc := range vk.ProviderConfigs { - if pc.Provider == provider && pc.RateLimit != nil { - if gs.updateSingleRateLimit(pc.RateLimit, tokensUsed, shouldUpdateTokens, shouldUpdateRequests) { - rateLimitsToUpdate = append(rateLimitsToUpdate, pc.RateLimit) + for _, rateLimitID := range rateLimitIDs { + // Update in-memory cache for next read (lock-free) + if cachedRateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && cachedRateLimitValue != nil { + if cachedRateLimit, ok := cachedRateLimitValue.(*configstoreTables.TableRateLimit); ok && cachedRateLimit != nil { + // Clone FIRST to avoid race conditions + clone := *cachedRateLimit + + // Check if rate limit needs reset (in-memory check) - operate on clone + if clone.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { + if now.Sub(clone.TokenLastReset) >= duration { + clone.TokenCurrentUsage = 0 + clone.TokenLastReset = now + } + } + } + if clone.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { + if now.Sub(clone.RequestLastReset) >= duration { + clone.RequestCurrentUsage = 0 + clone.RequestLastReset = now + } + } + } + + // Update the clone + if shouldUpdateTokens { + clone.TokenCurrentUsage += tokensUsed } - break + if shouldUpdateRequests { + clone.RequestCurrentUsage += 1 + } + gs.rateLimits.Store(rateLimitID, &clone) } } } + return nil +} - // Then, update VK-level rate limits if they exist - if vk.RateLimit != nil { - if gs.updateSingleRateLimit(vk.RateLimit, tokensUsed, shouldUpdateTokens, shouldUpdateRequests) { - rateLimitsToUpdate = append(rateLimitsToUpdate, vk.RateLimit) +// ResetExpiredBudgetsInMemory checks and resets budgets that have exceeded their reset duration (lock-free) +func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget { + now := time.Now() + var resetBudgets []*configstoreTables.TableBudget + + gs.budgets.Range(func(key, value interface{}) bool { + // Type-safe conversion + budget, ok := value.(*configstoreTables.TableBudget) + if !ok || budget == nil { + return true // continue } - } - // Save all updated rate limits to database - if len(rateLimitsToUpdate) > 0 && gs.configStore != nil { - if err := gs.configStore.UpdateRateLimits(ctx, rateLimitsToUpdate); err != nil { - return fmt.Errorf("failed to update rate limit usage: %w", err) + duration, err := configstoreTables.ParseDuration(budget.ResetDuration) + if err != nil { + gs.logger.Error("invalid budget reset duration %s: %v", budget.ResetDuration, err) + return true // continue } - } - return nil + if now.Sub(budget.LastReset) >= duration { + // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) + copiedBudget := *budget + oldUsage := copiedBudget.CurrentUsage + copiedBudget.CurrentUsage = 0 + copiedBudget.LastReset = now + copiedBudget.LastDBUsage = 0 + + // Atomically replace the entry using the original key + gs.budgets.Store(key, &copiedBudget) + resetBudgets = append(resetBudgets, &copiedBudget) + + // Update all VKs, teams, customers, and provider configs that reference this budget + gs.updateBudgetReferences(&copiedBudget) + + gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", + copiedBudget.ID, oldUsage)) + } + return true // continue + }) + + return resetBudgets } -// updateSingleRateLimit updates a single rate limit's counters and returns true if any changes were made -func (gs *GovernanceStore) updateSingleRateLimit(rateLimit *configstoreTables.TableRateLimit, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) bool { +// ResetExpiredRateLimitsInMemory performs background reset of expired rate limits for both provider-level and VK-level (lock-free) +func (gs *LocalGovernanceStore) ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit { now := time.Now() - updated := false + var resetRateLimits []*configstoreTables.TableRateLimit - // Check and reset token counter if needed - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if now.Sub(rateLimit.TokenLastReset) >= duration { - rateLimit.TokenCurrentUsage = 0 - rateLimit.TokenLastReset = now - updated = true - } + gs.rateLimits.Range(func(key, value interface{}) bool { + // Type-safe conversion + rateLimit, ok := value.(*configstoreTables.TableRateLimit) + if !ok || rateLimit == nil { + return true // continue } - } - // Check and reset request counter if needed - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if now.Sub(rateLimit.RequestLastReset) >= duration { - rateLimit.RequestCurrentUsage = 0 - rateLimit.RequestLastReset = now - updated = true + needsReset := false + // Check if token reset is needed + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if now.Sub(rateLimit.TokenLastReset) >= duration { + needsReset = true + } + } + } + // Check if request reset is needed + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if now.Sub(rateLimit.RequestLastReset) >= duration { + needsReset = true + } } } - } - // Update usage counters based on flags - if shouldUpdateTokens && tokensUsed > 0 { - rateLimit.TokenCurrentUsage += tokensUsed - updated = true - } + if needsReset { + // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) + copiedRateLimit := *rateLimit + + // Reset token limits if expired + if copiedRateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.TokenResetDuration); err == nil { + if now.Sub(copiedRateLimit.TokenLastReset) >= duration { + copiedRateLimit.TokenCurrentUsage = 0 + copiedRateLimit.TokenLastReset = now + copiedRateLimit.LastDBTokenUsage = 0 + } + } + } + // Reset request limits if expired + if copiedRateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.RequestResetDuration); err == nil { + if now.Sub(copiedRateLimit.RequestLastReset) >= duration { + copiedRateLimit.RequestCurrentUsage = 0 + copiedRateLimit.RequestLastReset = now + copiedRateLimit.LastDBRequestUsage = 0 + } + } + } - if shouldUpdateRequests { - rateLimit.RequestCurrentUsage += 1 - updated = true - } + // Atomically replace the entry using the original key + gs.rateLimits.Store(key, &copiedRateLimit) + resetRateLimits = append(resetRateLimits, &copiedRateLimit) - return updated -} + // Update all VKs and provider configs that reference this rate limit + gs.updateRateLimitReferences(&copiedRateLimit) + } + return true // continue + }) -// checkAndResetSingleRateLimit checks and resets a single rate limit's counters if expired -func (gs *GovernanceStore) checkAndResetSingleRateLimit(ctx context.Context, rateLimit *configstoreTables.TableRateLimit, now time.Time) bool { - updated := false + return resetRateLimits +} - // Check and reset token counter if needed - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if now.Sub(rateLimit.TokenLastReset).Round(time.Millisecond) >= duration { - rateLimit.TokenCurrentUsage = 0 - rateLimit.TokenLastReset = now - updated = true +// ResetExpiredBudgets checks and resets budgets that have exceeded their reset duration in database +func (gs *LocalGovernanceStore) ResetExpiredBudgets(ctx context.Context, resetBudgets []*configstoreTables.TableBudget) error { + // Persist to database if any resets occurred using direct UPDATE to avoid overwriting config fields + if len(resetBudgets) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, budget := range resetBudgets { + // Direct UPDATE only resets current_usage and last_reset + // This prevents overwriting max_limit or reset_duration that may have been changed by other nodes/requests + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableBudget{}). + Where("id = ?", budget.ID). + Updates(map[string]interface{}{ + "current_usage": budget.CurrentUsage, + "last_reset": budget.LastReset, + }) + + if result.Error != nil { + return fmt.Errorf("failed to reset budget %s: %w", budget.ID, result.Error) + } } + return nil + }); err != nil { + return fmt.Errorf("failed to persist budget resets to database: %w", err) } } - // Check and reset request counter if needed - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if now.Sub(rateLimit.RequestLastReset).Round(time.Millisecond) >= duration { - rateLimit.RequestCurrentUsage = 0 - rateLimit.RequestLastReset = now - updated = true + return nil +} + +// ResetExpiredRateLimits performs background reset of expired rate limits for both provider-level and VK-level in database +func (gs *LocalGovernanceStore) ResetExpiredRateLimits(ctx context.Context, resetRateLimits []*configstoreTables.TableRateLimit) error { + if len(resetRateLimits) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, rateLimit := range resetRateLimits { + // Build update map with only the fields that were reset + updates := make(map[string]interface{}) + + // Check which fields were reset by comparing with current values + if rateLimit.TokenCurrentUsage == 0 && rateLimit.TokenResetDuration != nil { + updates["token_current_usage"] = 0 + updates["token_last_reset"] = rateLimit.TokenLastReset + } + if rateLimit.RequestCurrentUsage == 0 && rateLimit.RequestResetDuration != nil { + updates["request_current_usage"] = 0 + updates["request_last_reset"] = rateLimit.RequestLastReset + } + + if len(updates) > 0 { + // Direct UPDATE only resets usage and last_reset fields + // This prevents overwriting max_limit or reset_duration that may have been changed by other nodes/requests + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableRateLimit{}). + Where("id = ?", rateLimit.ID). + Updates(updates) + + if result.Error != nil { + return fmt.Errorf("failed to reset rate limit %s: %w", rateLimit.ID, result.Error) + } + } } + return nil + }); err != nil { + return fmt.Errorf("failed to persist rate limit resets to database: %w", err) } } - - return updated + return nil } -// ResetExpiredRateLimits performs background reset of expired rate limits for both provider-level and VK-level (lock-free) -func (gs *GovernanceStore) ResetExpiredRateLimits(ctx context.Context) error { - now := time.Now() - var resetRateLimits []*configstoreTables.TableRateLimit +// DumpRateLimits dumps all rate limits to the database +func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselines map[string]int64, requestBaselines map[string]int64) error { + if gs.configStore == nil { + return nil + } + // This is to prevent nil pointer dereference + if tokenBaselines == nil { + tokenBaselines = map[string]int64{} + } + if requestBaselines == nil { + requestBaselines = map[string]int64{} + } + + // Collect unique rate limit IDs from virtual keys + rateLimitIDs := make(map[string]bool) gs.virtualKeys.Range(func(key, value interface{}) bool { - // Type-safe conversion vk, ok := value.(*configstoreTables.TableVirtualKey) if !ok || vk == nil { return true // continue } - - // Check provider-level rate limits + if vk.RateLimitID != nil { + rateLimitIDs[*vk.RateLimitID] = true + } if vk.ProviderConfigs != nil { for _, pc := range vk.ProviderConfigs { - if pc.RateLimit != nil { - if gs.checkAndResetSingleRateLimit(ctx, pc.RateLimit, now) { - resetRateLimits = append(resetRateLimits, pc.RateLimit) - } + if pc.RateLimitID != nil { + rateLimitIDs[*pc.RateLimitID] = true } } } - - // Check VK-level rate limits - if vk.RateLimit != nil { - if gs.checkAndResetSingleRateLimit(ctx, vk.RateLimit, now) { - resetRateLimits = append(resetRateLimits, vk.RateLimit) - } - } - return true // continue }) - // Persist reset rate limits to database - if len(resetRateLimits) > 0 && gs.configStore != nil { - if err := gs.configStore.UpdateRateLimits(ctx, resetRateLimits); err != nil { - return fmt.Errorf("failed to persist rate limit resets to database: %w", err) + // Prepare rate limit usage updates with baselines + type rateLimitUpdate struct { + ID string + TokenCurrentUsage int64 + RequestCurrentUsage int64 + } + var rateLimitUpdates []rateLimitUpdate + for rateLimitID := range rateLimitIDs { + if rateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && rateLimitValue != nil { + if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { + update := rateLimitUpdate{ + ID: rateLimit.ID, + TokenCurrentUsage: rateLimit.TokenCurrentUsage, + RequestCurrentUsage: rateLimit.RequestCurrentUsage, + } + if tokenBaseline, exists := tokenBaselines[rateLimit.ID]; exists { + update.TokenCurrentUsage += tokenBaseline + } + if requestBaseline, exists := requestBaselines[rateLimit.ID]; exists { + update.RequestCurrentUsage += requestBaseline + } + rateLimitUpdates = append(rateLimitUpdates, update) + } } } + // Save all updated rate limits to database using direct UPDATE to avoid overwriting config fields + if len(rateLimitUpdates) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, update := range rateLimitUpdates { + // Direct UPDATE only updates usage fields + // This prevents overwriting max_limit or reset_duration that may have been changed by other nodes/requests + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableRateLimit{}). + Where("id = ?", update.ID). + Updates(map[string]interface{}{ + "token_current_usage": update.TokenCurrentUsage, + "request_current_usage": update.RequestCurrentUsage, + }) + + if result.Error != nil { + return fmt.Errorf("failed to dump rate limit %s: %w", update.ID, result.Error) + } + } + return nil + }); err != nil { + // Check if error is a deadlock (SQLSTATE 40P01 for PostgreSQL, 1213 for MySQL) + errStr := err.Error() + isDeadlock := strings.Contains(errStr, "deadlock") || + strings.Contains(errStr, "40P01") || + strings.Contains(errStr, "1213") + + if isDeadlock { + // Deadlock means another node is updating the same rows - this is fine! + // Our usage data will be synced via gossip and written in the next dump cycle + gs.logger.Debug("Rate limit dump encountered deadlock (another node is updating) - will retry next cycle") + return nil // Not a real error in multi-node setup + } + return fmt.Errorf("failed to dump rate limits to database: %w", err) + } + } return nil } -// ResetExpiredBudgets checks and resets budgets that have exceeded their reset duration (lock-free) -func (gs *GovernanceStore) ResetExpiredBudgets(ctx context.Context) error { - now := time.Now() - var resetBudgets []*configstoreTables.TableBudget +// DumpBudgets dumps all budgets to the database +func (gs *LocalGovernanceStore) DumpBudgets(ctx context.Context, baselines map[string]float64) error { + if gs.configStore == nil { + return nil + } + + // This is to prevent nil pointer dereference + if baselines == nil { + baselines = map[string]float64{} + } + + budgets := make(map[string]*configstoreTables.TableBudget) gs.budgets.Range(func(key, value interface{}) bool { // Type-safe conversion - budget, ok := value.(*configstoreTables.TableBudget) - if !ok || budget == nil { - return true // continue - } + keyStr, keyOk := key.(string) + budget, budgetOk := value.(*configstoreTables.TableBudget) - duration, err := configstoreTables.ParseDuration(budget.ResetDuration) - if err != nil { - gs.logger.Error("invalid budget reset duration %s: %w", budget.ResetDuration, err) - return true // continue + if keyOk && budgetOk && budget != nil { + budgets[keyStr] = budget // Store budget by ID } + return true // continue iteration + }) - if now.Sub(budget.LastReset) >= duration { - oldUsage := budget.CurrentUsage - budget.CurrentUsage = 0 - budget.LastReset = now - resetBudgets = append(resetBudgets, budget) + if len(budgets) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + // Update each budget atomically using direct UPDATE to avoid deadlocks + // (SELECT + Save pattern causes deadlocks when multiple instances run concurrently) + for _, inMemoryBudget := range budgets { + // Calculate the new usage value + newUsage := inMemoryBudget.CurrentUsage + if baseline, exists := baselines[inMemoryBudget.ID]; exists { + newUsage += baseline + } - gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", - budget.ID, oldUsage)) - } - return true // continue - }) + // Direct UPDATE avoids read-then-write lock escalation that causes deadlocks + // Use Session with SkipHooks to avoid triggering BeforeSave hook validation + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableBudget{}). + Where("id = ?", inMemoryBudget.ID). + Update("current_usage", newUsage) - // Persist to database if any resets occurred - if len(resetBudgets) > 0 && gs.configStore != nil { - if err := gs.configStore.UpdateBudgets(ctx, resetBudgets); err != nil { - return fmt.Errorf("failed to persist budget resets to database: %w", err) + if result.Error != nil { + return fmt.Errorf("failed to update budget %s: %w", inMemoryBudget.ID, result.Error) + } + } + return nil + }); err != nil { + // Check if error is a deadlock (SQLSTATE 40P01 for PostgreSQL, 1213 for MySQL) + errStr := err.Error() + isDeadlock := strings.Contains(errStr, "deadlock") || + strings.Contains(errStr, "40P01") || + strings.Contains(errStr, "1213") + + if isDeadlock { + // Deadlock means another node is updating the same rows - this is fine! + // Our usage data will be synced via gossip and written in the next dump cycle + gs.logger.Debug("Budget dump encountered deadlock (another node is updating) - will retry next cycle") + return nil // Not a real error in multi-node setup + } + return fmt.Errorf("failed to dump budgets to database: %w", err) } } @@ -376,7 +747,7 @@ func (gs *GovernanceStore) ResetExpiredBudgets(ctx context.Context) error { // DATABASE METHODS // loadFromDatabase loads all governance data from the database into memory -func (gs *GovernanceStore) loadFromDatabase(ctx context.Context) error { +func (gs *LocalGovernanceStore) loadFromDatabase(ctx context.Context) error { // Load customers with their budgets customers, err := gs.configStore.GetCustomers(ctx) if err != nil { @@ -401,14 +772,20 @@ func (gs *GovernanceStore) loadFromDatabase(ctx context.Context) error { return fmt.Errorf("failed to load budgets: %w", err) } + // Load rate limits + rateLimits, err := gs.configStore.GetRateLimits(ctx) + if err != nil { + return fmt.Errorf("failed to load rate limits: %w", err) + } + // Rebuild in-memory structures (lock-free) - gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets) + gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits) return nil } // loadFromConfigMemory loads all governance data from the config's memory into store's memory -func (gs *GovernanceStore) loadFromConfigMemory(ctx context.Context, config *configstore.GovernanceConfig) error { +func (gs *LocalGovernanceStore) loadFromConfigMemory(ctx context.Context, config *configstore.GovernanceConfig) error { if config == nil { return fmt.Errorf("governance config is nil") } @@ -456,22 +833,50 @@ func (gs *GovernanceStore) loadFromConfigMemory(ctx context.Context, config *con } } + // Populate provider config relationships with budgets and rate limits + if vk.ProviderConfigs != nil { + for j := range vk.ProviderConfigs { + pc := &vk.ProviderConfigs[j] + + // Populate budget + if pc.BudgetID != nil { + for k := range budgets { + if budgets[k].ID == *pc.BudgetID { + pc.Budget = &budgets[k] + break + } + } + } + + // Populate rate limit + if pc.RateLimitID != nil { + for k := range rateLimits { + if rateLimits[k].ID == *pc.RateLimitID { + pc.RateLimit = &rateLimits[k] + break + } + } + } + } + } + virtualKeys[i] = *vk } // Rebuild in-memory structures (lock-free) - gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets) + gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits) return nil } // rebuildInMemoryStructures rebuilds all in-memory data structures (lock-free) -func (gs *GovernanceStore) rebuildInMemoryStructures(ctx context.Context, customers []configstoreTables.TableCustomer, teams []configstoreTables.TableTeam, virtualKeys []configstoreTables.TableVirtualKey, budgets []configstoreTables.TableBudget) { +func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, customers []configstoreTables.TableCustomer, teams []configstoreTables.TableTeam, virtualKeys []configstoreTables.TableVirtualKey, budgets []configstoreTables.TableBudget, rateLimits []configstoreTables.TableRateLimit) { // Clear existing data by creating new sync.Maps gs.virtualKeys = sync.Map{} gs.teams = sync.Map{} gs.customers = sync.Map{} gs.budgets = sync.Map{} + gs.rateLimits = sync.Map{} // Build customers map for i := range customers { @@ -491,6 +896,12 @@ func (gs *GovernanceStore) rebuildInMemoryStructures(ctx context.Context, custom gs.budgets.Store(budget.ID, budget) } + // Build rate limits map + for i := range rateLimits { + rateLimit := &rateLimits[i] + gs.rateLimits.Store(rateLimit.ID, rateLimit) + } + // Build virtual keys map and track active VKs for i := range virtualKeys { vk := &virtualKeys[i] @@ -500,8 +911,40 @@ func (gs *GovernanceStore) rebuildInMemoryStructures(ctx context.Context, custom // UTILITY FUNCTIONS +// collectRateLimitsFromHierarchy collects rate limits and their metadata from the hierarchy (Provider Configs → VK) +func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) ([]*configstoreTables.TableRateLimit, []string) { + if vk == nil { + return nil, nil + } + + var rateLimits []*configstoreTables.TableRateLimit + var rateLimitNames []string + + for _, pc := range vk.ProviderConfigs { + if pc.RateLimitID != nil && pc.Provider == string(requestedProvider) { + if rateLimitValue, exists := gs.rateLimits.Load(*pc.RateLimitID); exists && rateLimitValue != nil { + if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { + rateLimits = append(rateLimits, rateLimit) + rateLimitNames = append(rateLimitNames, pc.Provider) + } + } + } + } + + if vk.RateLimitID != nil { + if rateLimitValue, exists := gs.rateLimits.Load(*vk.RateLimitID); exists && rateLimitValue != nil { + if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { + rateLimits = append(rateLimits, rateLimit) + rateLimitNames = append(rateLimitNames, "VK") + } + } + } + + return rateLimits, rateLimitNames +} + // collectBudgetsFromHierarchy collects budgets and their metadata from the hierarchy (Provider Configs → VK → Team → Customer) -func (gs *GovernanceStore) collectBudgetsFromHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) ([]*configstoreTables.TableBudget, []string) { +func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) ([]*configstoreTables.TableBudget, []string) { if vk == nil { return nil, nil } @@ -511,7 +954,7 @@ func (gs *GovernanceStore) collectBudgetsFromHierarchy(ctx context.Context, vk * // Collect all budgets in hierarchy order using lock-free sync.Map access (Provider Configs → VK → Team → Customer) for _, pc := range vk.ProviderConfigs { - if pc.BudgetID != nil && pc.Provider == string(provider) { + if pc.BudgetID != nil && pc.Provider == string(requestedProvider) { if budgetValue, exists := gs.budgets.Load(*pc.BudgetID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { budgets = append(budgets, budget) @@ -580,8 +1023,8 @@ func (gs *GovernanceStore) collectBudgetsFromHierarchy(ctx context.Context, vk * } // collectBudgetIDsFromMemory collects budget IDs from in-memory store data (lock-free) -func (gs *GovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { - budgets, _ := gs.collectBudgetsFromHierarchy(ctx, vk, provider) +func (gs *LocalGovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { + budgets, _ := gs.collectBudgetsFromHierarchy(vk, provider) budgetIDs := make([]string, len(budgets)) for i, budget := range budgets { @@ -591,49 +1034,195 @@ func (gs *GovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *c return budgetIDs } -// resetBudgetIfNeeded checks and resets budget within a transaction -func (gs *GovernanceStore) resetBudgetIfNeeded(ctx context.Context, tx *gorm.DB, budget *configstoreTables.TableBudget) error { - duration, err := configstoreTables.ParseDuration(budget.ResetDuration) - if err != nil { - return fmt.Errorf("invalid reset duration %s: %w", budget.ResetDuration, err) - } - - now := time.Now() - if now.Sub(budget.LastReset) >= duration { - budget.CurrentUsage = 0 - budget.LastReset = now +// collectRateLimitIDsFromMemory collects rate limit IDs from in-memory store data (lock-free) +func (gs *LocalGovernanceStore) collectRateLimitIDsFromMemory(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { + rateLimits, _ := gs.collectRateLimitsFromHierarchy(vk, provider) - if gs.configStore != nil { - // Save reset to database - if err := gs.configStore.UpdateBudget(ctx, budget, tx); err != nil { - return fmt.Errorf("failed to save budget reset: %w", err) - } - } + rateLimitIDs := make([]string, len(rateLimits)) + for i, rateLimit := range rateLimits { + rateLimitIDs[i] = rateLimit.ID } - return nil + return rateLimitIDs } // PUBLIC API METHODS // CreateVirtualKeyInMemory adds a new virtual key to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { // with rateLimit preloaded +func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { if vk == nil { return // Nothing to create } + + // Create associated budget if exists + if vk.Budget != nil { + gs.budgets.Store(vk.Budget.ID, vk.Budget) + } + + // Create associated rate limit if exists + if vk.RateLimit != nil { + gs.rateLimits.Store(vk.RateLimit.ID, vk.RateLimit) + } + + // Create provider config budgets and rate limits if they exist + if vk.ProviderConfigs != nil { + for _, pc := range vk.ProviderConfigs { + if pc.Budget != nil { + gs.budgets.Store(pc.Budget.ID, pc.Budget) + } + if pc.RateLimit != nil { + gs.rateLimits.Store(pc.RateLimit.ID, pc.RateLimit) + } + } + } + gs.virtualKeys.Store(vk.Value, vk) } // UpdateVirtualKeyInMemory updates an existing virtual key in the in-memory store (lock-free) -func (gs *GovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { // with rateLimit preloaded +func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) { if vk == nil { return // Nothing to update } - gs.virtualKeys.Store(vk.Value, vk) + if budgetBaselines == nil { + budgetBaselines = make(map[string]float64) + } + if rateLimitTokensBaselines == nil { + rateLimitTokensBaselines = make(map[string]int64) + } + if rateLimitRequestsBaselines == nil { + rateLimitRequestsBaselines = make(map[string]int64) + } + // Do not update the current usage of the rate limit, as it will be updated by the usage tracker. + // But update if max limit or reset duration changes. + if existingVKValue, exists := gs.virtualKeys.Load(vk.Value); exists && existingVKValue != nil { + existingVK, ok := existingVKValue.(*configstoreTables.TableVirtualKey) + if !ok || existingVK == nil { + return // Nothing to update + } + // Create clone to avoid modifying the original + clone := *vk + // Update Budget using checkAndUpdateBudget logic (preserve usage unless currentUsage+baseline > newMaxLimit) + if clone.Budget != nil { + // Get existing budget from gs.budgets (NOT from VK.Budget which may be stale) + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[clone.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, budgetBaseline) + // Update the budget in the main budgets sync.Map + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } else if existingVK.Budget != nil { + // Budget was removed from the virtual key, delete it from memory + gs.budgets.Delete(existingVK.Budget.ID) + } + if clone.RateLimit != nil { + // Get existing rate limit from gs.rateLimits (NOT from VK.RateLimit which may be stale) + var existingRateLimit *configstoreTables.TableRateLimit + if existingRateLimitValue, exists := gs.rateLimits.Load(clone.RateLimit.ID); exists && existingRateLimitValue != nil { + if erl, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && erl != nil { + existingRateLimit = erl + } + } + tokenBaseline, exists := rateLimitTokensBaselines[clone.RateLimit.ID] + if !exists { + tokenBaseline = 0 + } + requestBaseline, exists := rateLimitRequestsBaselines[clone.RateLimit.ID] + if !exists { + requestBaseline = 0 + } + clone.RateLimit = checkAndUpdateRateLimit(clone.RateLimit, existingRateLimit, tokenBaseline, requestBaseline) + // Update the rate limit in the main rateLimits sync.Map + if clone.RateLimit != nil { + gs.rateLimits.Store(clone.RateLimit.ID, clone.RateLimit) + } + } else if existingVK.RateLimit != nil { + // Rate limit was removed from the virtual key, delete it from memory + gs.rateLimits.Delete(existingVK.RateLimit.ID) + } + if clone.ProviderConfigs != nil { + // Create a map of existing provider configs by ID for fast lookup + existingProviderConfigs := make(map[uint]configstoreTables.TableVirtualKeyProviderConfig) + if existingVK.ProviderConfigs != nil { + for _, existingPC := range existingVK.ProviderConfigs { + existingProviderConfigs[existingPC.ID] = existingPC + } + } + + // Process each new/updated provider config + for i, pc := range clone.ProviderConfigs { + if pc.RateLimit != nil { + // Get existing rate limit from gs.rateLimits (NOT from provider config which may be stale) + var existingProviderRateLimit *configstoreTables.TableRateLimit + if existingRateLimitValue, exists := gs.rateLimits.Load(pc.RateLimit.ID); exists && existingRateLimitValue != nil { + if erl, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && erl != nil { + existingProviderRateLimit = erl + } + } + tokenBaseline, exists := rateLimitTokensBaselines[pc.RateLimit.ID] + if !exists { + tokenBaseline = 0 + } + requestBaseline, exists := rateLimitRequestsBaselines[pc.RateLimit.ID] + if !exists { + requestBaseline = 0 + } + clone.ProviderConfigs[i].RateLimit = checkAndUpdateRateLimit(pc.RateLimit, existingProviderRateLimit, tokenBaseline, requestBaseline) + // Also update the rate limit in the main rateLimits sync.Map + if clone.ProviderConfigs[i].RateLimit != nil { + gs.rateLimits.Store(clone.ProviderConfigs[i].RateLimit.ID, clone.ProviderConfigs[i].RateLimit) + } + } else { + // Rate limit was removed from provider config, delete it from memory if it existed + if existingPC, exists := existingProviderConfigs[pc.ID]; exists && existingPC.RateLimit != nil { + gs.rateLimits.Delete(existingPC.RateLimit.ID) + clone.ProviderConfigs[i].RateLimit = nil + } + } + // Update Budget for provider config (preserve usage unless currentUsage+baseline > newMaxLimit) + if pc.Budget != nil { + // Get existing budget from gs.budgets (NOT from provider config which may be stale) + var existingProviderBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(pc.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingProviderBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[pc.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.ProviderConfigs[i].Budget = checkAndUpdateBudget(pc.Budget, existingProviderBudget, budgetBaseline) + // Also update the budget in the main budgets sync.Map + if clone.ProviderConfigs[i].Budget != nil { + gs.budgets.Store(clone.ProviderConfigs[i].Budget.ID, clone.ProviderConfigs[i].Budget) + } + } else { + // Budget was removed from provider config, delete it from memory if it existed + if existingPC, exists := existingProviderConfigs[pc.ID]; exists && existingPC.Budget != nil { + gs.budgets.Delete(existingPC.Budget.ID) + clone.ProviderConfigs[i].Budget = nil + } + } + } + } + gs.virtualKeys.Store(vk.Value, &clone) + } else { + gs.CreateVirtualKeyInMemory(vk) + } } // DeleteVirtualKeyInMemory removes a virtual key from the in-memory store -func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { +func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(vkID string) { if vkID == "" { return // Nothing to delete } @@ -647,6 +1236,28 @@ func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { } if vk.ID == vkID { + // Delete associated budget if exists + if vk.BudgetID != nil { + gs.budgets.Delete(*vk.BudgetID) + } + + // Delete associated rate limit if exists + if vk.RateLimitID != nil { + gs.rateLimits.Delete(*vk.RateLimitID) + } + + // Delete provider config budgets and rate limits + if vk.ProviderConfigs != nil { + for _, pc := range vk.ProviderConfigs { + if pc.BudgetID != nil { + gs.budgets.Delete(*pc.BudgetID) + } + if pc.RateLimitID != nil { + gs.rateLimits.Delete(*pc.RateLimitID) + } + } + } + gs.virtualKeys.Delete(key) return false // stop iteration } @@ -655,74 +1266,403 @@ func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { } // CreateTeamInMemory adds a new team to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateTeamInMemory(team *configstoreTables.TableTeam) { +func (gs *LocalGovernanceStore) CreateTeamInMemory(team *configstoreTables.TableTeam) { if team == nil { return // Nothing to create } + + // Create associated budget if exists + if team.Budget != nil { + gs.budgets.Store(team.Budget.ID, team.Budget) + } + gs.teams.Store(team.ID, team) } // UpdateTeamInMemory updates an existing team in the in-memory store (lock-free) -func (gs *GovernanceStore) UpdateTeamInMemory(team *configstoreTables.TableTeam) { +func (gs *LocalGovernanceStore) UpdateTeamInMemory(team *configstoreTables.TableTeam, budgetBaselines map[string]float64) { if team == nil { return // Nothing to update } - gs.teams.Store(team.ID, team) + if budgetBaselines == nil { + budgetBaselines = make(map[string]float64) + } + + // Check if there's an existing team to get current budget state + if existingTeamValue, exists := gs.teams.Load(team.ID); exists && existingTeamValue != nil { + existingTeam, ok := existingTeamValue.(*configstoreTables.TableTeam) + if !ok || existingTeam == nil { + return // Nothing to update + } + // Create clone to avoid modifying the original + clone := *team + + // Handle budget updates with consistent logic + if clone.Budget != nil { + // Get existing budget from gs.budgets (NOT from Team.Budget which may be stale) + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[clone.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, budgetBaseline) + // Update the budget in the main budgets sync.Map + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } else if existingTeam.Budget != nil { + // Budget was removed from the team, delete it from memory + gs.budgets.Delete(existingTeam.Budget.ID) + } + + gs.teams.Store(team.ID, &clone) + } else { + gs.CreateTeamInMemory(team) + } } // DeleteTeamInMemory removes a team from the in-memory store (lock-free) -func (gs *GovernanceStore) DeleteTeamInMemory(teamID string) { +func (gs *LocalGovernanceStore) DeleteTeamInMemory(teamID string) { if teamID == "" { return // Nothing to delete } + + // Get team to check for associated budget + if teamValue, exists := gs.teams.Load(teamID); exists && teamValue != nil { + if team, ok := teamValue.(*configstoreTables.TableTeam); ok && team != nil { + // Delete associated budget if exists + if team.BudgetID != nil { + gs.budgets.Delete(*team.BudgetID) + } + } + } + + // Set team_id to null for all virtual keys associated with the team + // Iterate through all VKs since team.VirtualKeys may not be populated + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + if vk.TeamID != nil && *vk.TeamID == teamID { + clone := *vk + clone.TeamID = nil + clone.Team = nil + gs.virtualKeys.Store(key, &clone) + } + return true // continue iteration + }) + gs.teams.Delete(teamID) } // CreateCustomerInMemory adds a new customer to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateCustomerInMemory(customer *configstoreTables.TableCustomer) { +func (gs *LocalGovernanceStore) CreateCustomerInMemory(customer *configstoreTables.TableCustomer) { if customer == nil { return // Nothing to create } + + // Create associated budget if exists + if customer.Budget != nil { + gs.budgets.Store(customer.Budget.ID, customer.Budget) + } + gs.customers.Store(customer.ID, customer) } // UpdateCustomerInMemory updates an existing customer in the in-memory store (lock-free) -func (gs *GovernanceStore) UpdateCustomerInMemory(customer *configstoreTables.TableCustomer) { +func (gs *LocalGovernanceStore) UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) { if customer == nil { return // Nothing to update } - gs.customers.Store(customer.ID, customer) + if budgetBaselines == nil { + budgetBaselines = make(map[string]float64) + } + + // Check if there's an existing customer to get current budget state + if existingCustomerValue, exists := gs.customers.Load(customer.ID); exists && existingCustomerValue != nil { + existingCustomer, ok := existingCustomerValue.(*configstoreTables.TableCustomer) + if !ok || existingCustomer == nil { + return // Nothing to update + } + // Create clone to avoid modifying the original + clone := *customer + + // Handle budget updates with consistent logic + if clone.Budget != nil { + // Get existing budget from gs.budgets (NOT from Customer.Budget which may be stale) + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[clone.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, budgetBaseline) + // Update the budget in the main budgets sync.Map + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } else if existingCustomer.Budget != nil { + // Budget was removed from the customer, delete it from memory + gs.budgets.Delete(existingCustomer.Budget.ID) + } + + gs.customers.Store(customer.ID, &clone) + } else { + gs.CreateCustomerInMemory(customer) + } } // DeleteCustomerInMemory removes a customer from the in-memory store (lock-free) -func (gs *GovernanceStore) DeleteCustomerInMemory(customerID string) { +func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { if customerID == "" { return // Nothing to delete } + + // Get customer to check for associated budget + if customerValue, exists := gs.customers.Load(customerID); exists && customerValue != nil { + if customer, ok := customerValue.(*configstoreTables.TableCustomer); ok && customer != nil { + // Delete associated budget if exists + if customer.BudgetID != nil { + gs.budgets.Delete(*customer.BudgetID) + } + } + } + + // Set customer_id to null for all virtual keys associated with the customer + // Iterate through all VKs since customer.VirtualKeys may not be populated + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + if vk.CustomerID != nil && *vk.CustomerID == customerID { + clone := *vk + clone.CustomerID = nil + clone.Customer = nil + gs.virtualKeys.Store(key, &clone) + } + return true // continue iteration + }) + + // Set customer_id to null for all teams associated with the customer + // Iterate through all teams since customer.Teams may not be populated + gs.teams.Range(func(key, value interface{}) bool { + team, ok := value.(*configstoreTables.TableTeam) + if !ok || team == nil { + return true // continue + } + if team.CustomerID != nil && *team.CustomerID == customerID { + clone := *team + clone.CustomerID = nil + clone.Customer = nil + gs.teams.Store(key, &clone) + } + return true // continue iteration + }) + gs.customers.Delete(customerID) } -// CreateBudgetInMemory adds a new budget to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateBudgetInMemory(budget *configstoreTables.TableBudget) { - if budget == nil { - return // Nothing to create - } - gs.budgets.Store(budget.ID, budget) +// Helper functions + +// updateBudgetReferences updates all VKs, teams, customers, and provider configs that reference a reset budget +func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreTables.TableBudget) { + budgetID := resetBudget.ID + // Update VKs that reference this budget + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + needsUpdate := false + clone := *vk + + // Check VK-level budget + if vk.BudgetID != nil && *vk.BudgetID == budgetID { + clone.Budget = resetBudget + needsUpdate = true + } + + // Check provider config budgets + if vk.ProviderConfigs != nil { + for i, pc := range clone.ProviderConfigs { + if pc.BudgetID != nil && *pc.BudgetID == budgetID { + clone.ProviderConfigs[i].Budget = resetBudget + needsUpdate = true + } + } + } + + if needsUpdate { + gs.virtualKeys.Store(key, &clone) + } + return true // continue + }) + + // Update teams that reference this budget + gs.teams.Range(func(key, value interface{}) bool { + team, ok := value.(*configstoreTables.TableTeam) + if !ok || team == nil { + return true // continue + } + if team.BudgetID != nil && *team.BudgetID == budgetID { + clone := *team + clone.Budget = resetBudget + gs.teams.Store(key, &clone) + } + return true // continue + }) + + // Update customers that reference this budget + gs.customers.Range(func(key, value interface{}) bool { + customer, ok := value.(*configstoreTables.TableCustomer) + if !ok || customer == nil { + return true // continue + } + if customer.BudgetID != nil && *customer.BudgetID == budgetID { + clone := *customer + clone.Budget = resetBudget + gs.customers.Store(key, &clone) + } + return true // continue + }) +} + +// updateRateLimitReferences updates all VKs and provider configs that reference a reset rate limit +func (gs *LocalGovernanceStore) updateRateLimitReferences(resetRateLimit *configstoreTables.TableRateLimit) { + rateLimitID := resetRateLimit.ID + // Update VKs that reference this rate limit + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + needsUpdate := false + clone := *vk + + // Check VK-level rate limit + if vk.RateLimitID != nil && *vk.RateLimitID == rateLimitID { + clone.RateLimit = resetRateLimit + needsUpdate = true + } + + // Check provider config rate limits + if vk.ProviderConfigs != nil { + for i, pc := range clone.ProviderConfigs { + if pc.RateLimitID != nil && *pc.RateLimitID == rateLimitID { + clone.ProviderConfigs[i].RateLimit = resetRateLimit + needsUpdate = true + } + } + } + + if needsUpdate { + gs.virtualKeys.Store(key, &clone) + } + return true // continue + }) } -// UpdateBudgetInMemory updates a specific budget in the in-memory cache (lock-free) -func (gs *GovernanceStore) UpdateBudgetInMemory(budget *configstoreTables.TableBudget) error { - if budget == nil { - return fmt.Errorf("budget cannot be nil") +// checkAndUpdateBudget checks and updates a budget with usage reset logic +// If currentUsage+baseline >= newMaxLimit, reset usage to 0 +// Otherwise preserve existing usage and accept reset duration and max limit changes +func checkAndUpdateBudget(budgetToUpdate *configstoreTables.TableBudget, existingBudget *configstoreTables.TableBudget, baseline float64) *configstoreTables.TableBudget { + // Create clone to avoid modifying the original + clone := *budgetToUpdate + if existingBudget == nil { + // New budget, return as-is + return budgetToUpdate } - gs.budgets.Store(budget.ID, budget) - return nil + + // Check if reset duration or max limit changed + resetDurationChanged := budgetToUpdate.ResetDuration != existingBudget.ResetDuration + maxLimitChanged := budgetToUpdate.MaxLimit != existingBudget.MaxLimit + + if resetDurationChanged || maxLimitChanged { + // If currentUsage + baseline >= new max limit, reset usage to 0 + // This handles the case where new max limit is lower than or equal to current usage + if existingBudget.CurrentUsage+baseline >= budgetToUpdate.MaxLimit { + clone.CurrentUsage = 0 + } else { + // Otherwise, preserve the existing usage from memory (which may have been updated) + clone.CurrentUsage = existingBudget.CurrentUsage + // Preserve LastDBUsage baseline to prevent multi-node baseline corruption + clone.LastDBUsage = existingBudget.LastDBUsage + } + } else { + // No changes to max limit or reset duration, preserve existing usage + clone.CurrentUsage = existingBudget.CurrentUsage + // Preserve LastDBUsage baseline to prevent multi-node baseline corruption + clone.LastDBUsage = existingBudget.LastDBUsage + } + + return &clone } -// DeleteBudgetInMemory removes a budget from the in-memory store (lock-free) -func (gs *GovernanceStore) DeleteBudgetInMemory(budgetID string) { - if budgetID == "" { - return // Nothing to delete +// checkAndUpdateRateLimit checks and updates a rate limit with usage reset logic +// If currentUsage+baseline > newMaxLimit, reset usage to 0 +// Otherwise preserve existing usage and accept reset duration and max limit changes +func checkAndUpdateRateLimit(rateLimitToUpdate *configstoreTables.TableRateLimit, existingRateLimit *configstoreTables.TableRateLimit, tokenBaseline int64, requestBaseline int64) *configstoreTables.TableRateLimit { + // Create clone to avoid modifying the original + clone := *rateLimitToUpdate + if existingRateLimit == nil { + // New rate limit, return as-is + return rateLimitToUpdate + } + + // Check if token settings changed + tokenMaxLimitChanged := !equalPtr(existingRateLimit.TokenMaxLimit, rateLimitToUpdate.TokenMaxLimit) + tokenResetDurationChanged := !equalPtr(existingRateLimit.TokenResetDuration, rateLimitToUpdate.TokenResetDuration) + + // Check if request settings changed + requestMaxLimitChanged := !equalPtr(existingRateLimit.RequestMaxLimit, rateLimitToUpdate.RequestMaxLimit) + requestResetDurationChanged := !equalPtr(existingRateLimit.RequestResetDuration, rateLimitToUpdate.RequestResetDuration) + + if tokenMaxLimitChanged || tokenResetDurationChanged { + // If currentUsage + baseline >= new max limit, reset usage to 0 + // This handles the case where new max limit is lower than or equal to current usage + if rateLimitToUpdate.TokenMaxLimit != nil && existingRateLimit.TokenCurrentUsage+tokenBaseline >= *rateLimitToUpdate.TokenMaxLimit { + clone.TokenCurrentUsage = 0 + } else { + // Otherwise, preserve the existing usage + clone.TokenCurrentUsage = existingRateLimit.TokenCurrentUsage + // Preserve LastDBTokenUsage baseline to prevent multi-node baseline corruption + clone.LastDBTokenUsage = existingRateLimit.LastDBTokenUsage + } + } else { + // No changes to max limit or reset duration, preserve existing usage + clone.TokenCurrentUsage = existingRateLimit.TokenCurrentUsage + // Preserve LastDBTokenUsage baseline to prevent multi-node baseline corruption + clone.LastDBTokenUsage = existingRateLimit.LastDBTokenUsage } - gs.budgets.Delete(budgetID) + + if requestMaxLimitChanged || requestResetDurationChanged { + // If currentUsage + baseline >= new max limit, reset usage to 0 + // This handles the case where new max limit is lower than or equal to current usage + if rateLimitToUpdate.RequestMaxLimit != nil && existingRateLimit.RequestCurrentUsage+requestBaseline >= *rateLimitToUpdate.RequestMaxLimit { + clone.RequestCurrentUsage = 0 + } else { + // Otherwise, preserve the existing usage + clone.RequestCurrentUsage = existingRateLimit.RequestCurrentUsage + // Preserve LastDBRequestUsage baseline to prevent multi-node baseline corruption + clone.LastDBRequestUsage = existingRateLimit.LastDBRequestUsage + } + } else { + // No changes to max limit or reset duration, preserve existing usage + clone.RequestCurrentUsage = existingRateLimit.RequestCurrentUsage + // Preserve LastDBRequestUsage baseline to prevent multi-node baseline corruption + clone.LastDBRequestUsage = existingRateLimit.LastDBRequestUsage + } + + return &clone } diff --git a/plugins/governance/store_test.go b/plugins/governance/store_test.go new file mode 100644 index 000000000..0793df541 --- /dev/null +++ b/plugins/governance/store_test.go @@ -0,0 +1,351 @@ +package governance + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGovernanceStore_GetVirtualKey tests lock-free VK retrieval +func TestGovernanceStore_GetVirtualKey(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{ + *buildVirtualKey("vk1", "sk-bf-test1", "Test VK 1", true), + *buildVirtualKey("vk2", "sk-bf-test2", "Test VK 2", false), + }, + }) + require.NoError(t, err) + + tests := []struct { + name string + vkValue string + wantNil bool + wantID string + }{ + { + name: "Found active VK", + vkValue: "sk-bf-test1", + wantNil: false, + wantID: "vk1", + }, + { + name: "Found inactive VK", + vkValue: "sk-bf-test2", + wantNil: false, + wantID: "vk2", + }, + { + name: "VK not found", + vkValue: "sk-bf-nonexistent", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vk, exists := store.GetVirtualKey(tt.vkValue) + if tt.wantNil { + assert.False(t, exists) + assert.Nil(t, vk) + } else { + assert.True(t, exists) + assert.NotNil(t, vk) + assert.Equal(t, tt.wantID, vk.ID) + } + }) + } +} + +// TestGovernanceStore_ConcurrentReads tests lock-free concurrent reads +func TestGovernanceStore_ConcurrentReads(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + // Launch 100 concurrent readers + var wg sync.WaitGroup + readCount := atomic.Int64{} + errorCount := atomic.Int64{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + vk, exists := store.GetVirtualKey("sk-bf-test") + if !exists || vk == nil { + errorCount.Add(1) + return + } + readCount.Add(1) + } + }() + } + + wg.Wait() + + assert.Equal(t, int64(10000), readCount.Load(), "Expected 10000 successful reads") + assert.Equal(t, int64(0), errorCount.Load(), "Expected 0 errors") +} + +// TestGovernanceStore_CheckBudget_SingleBudget tests budget validation with single budget +func TestGovernanceStore_CheckBudget_SingleBudget(t *testing.T) { + logger := NewMockLogger() + budget := buildBudgetWithUsage("budget1", 100.0, 50.0, "1d") + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + // Retrieve VK with budget + vk, _ = store.GetVirtualKey("sk-bf-test") + + tests := []struct { + name string + usage float64 + maxLimit float64 + shouldErr bool + }{ + { + name: "Usage below limit", + usage: 50.0, + maxLimit: 100.0, + shouldErr: false, + }, + { + name: "Usage at limit (should fail)", + usage: 100.0, + maxLimit: 100.0, + shouldErr: true, + }, + { + name: "Usage exceeds limit", + usage: 150.0, + maxLimit: 100.0, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create new budget with test usage + testBudget := buildBudgetWithUsage("budget1", tt.maxLimit, tt.usage, "1d") + testVK := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", testBudget) + testStore, _ := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*testVK}, + Budgets: []configstoreTables.TableBudget{*testBudget}, + }) + + testVK, _ = testStore.GetVirtualKey("sk-bf-test") + err := testStore.CheckBudget(context.Background(), testVK, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + if tt.shouldErr { + assert.Error(t, err, "Expected error for usage check") + } else { + assert.NoError(t, err, "Expected no error for usage check") + } + }) + } +} + +// TestGovernanceStore_CheckBudget_HierarchyValidation tests multi-level budget hierarchy +func TestGovernanceStore_CheckBudget_HierarchyValidation(t *testing.T) { + logger := NewMockLogger() + + // Create budgets at different levels + vkBudget := buildBudgetWithUsage("vk-budget", 100.0, 50.0, "1d") + teamBudget := buildBudgetWithUsage("team-budget", 500.0, 200.0, "1d") + customerBudget := buildBudgetWithUsage("customer-budget", 1000.0, 400.0, "1d") + + // Build hierarchy + team := buildTeam("team1", "Team 1", teamBudget) + customer := buildCustomer("customer1", "Customer 1", customerBudget) + team.CustomerID = &customer.ID + team.Customer = customer + + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", vkBudget) + vk.TeamID = &team.ID + vk.Team = team + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*vkBudget, *teamBudget, *customerBudget}, + Teams: []configstoreTables.TableTeam{*team}, + Customers: []configstoreTables.TableCustomer{*customer}, + }) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + + // Test: All budgets under limit should pass + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + assert.NoError(t, err, "Should pass when all budgets are under limit") + + // Test: If VK budget exceeds limit, should fail + // Update the budget directly in the budgets map (since UpdateVirtualKeyInMemory preserves usage) + if vk.BudgetID != nil { + if budgetValue, exists := store.budgets.Load(*vk.BudgetID); exists && budgetValue != nil { + if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { + budget.CurrentUsage = 100.0 + store.budgets.Store(*vk.BudgetID, budget) + } + } + } + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + assert.Error(t, err, "Should fail when VK budget exceeds limit") +} + +// TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests tests atomic rate limit usage updates +func TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests(t *testing.T) { + logger := NewMockLogger() + + rateLimit := buildRateLimitWithUsage("rl1", 10000, 0, 1000, 0) + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + // Test updating tokens + err = store.UpdateRateLimitUsageInMemory(context.Background(), vk, schemas.OpenAI, 500, true, false) + assert.NoError(t, err, "Rate limit update should succeed") + + // Retrieve the updated rate limit from the main RateLimits map + governanceData := store.GetGovernanceData() + updatedRateLimit, exists := governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + assert.Equal(t, int64(500), updatedRateLimit.TokenCurrentUsage, "Token usage should be updated") + assert.Equal(t, int64(0), updatedRateLimit.RequestCurrentUsage, "Request usage should not change") + + // Test updating requests + err = store.UpdateRateLimitUsageInMemory(context.Background(), vk, schemas.OpenAI, 0, false, true) + assert.NoError(t, err, "Rate limit update should succeed") + + // Retrieve the updated rate limit again + governanceData = store.GetGovernanceData() + updatedRateLimit, exists = governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + assert.Equal(t, int64(500), updatedRateLimit.TokenCurrentUsage, "Token usage should not change") + assert.Equal(t, int64(1), updatedRateLimit.RequestCurrentUsage, "Request usage should be incremented") +} + +// TestGovernanceStore_ResetExpiredRateLimits tests rate limit reset +func TestGovernanceStore_ResetExpiredRateLimits(t *testing.T) { + logger := NewMockLogger() + + // Create rate limit that's already expired + duration := "1m" + rateLimit := &configstoreTables.TableRateLimit{ + ID: "rl1", + TokenMaxLimit: ptrInt64(10000), + TokenCurrentUsage: 5000, + TokenResetDuration: &duration, + TokenLastReset: time.Now().Add(-2 * time.Minute), // Expired + RequestMaxLimit: ptrInt64(1000), + RequestCurrentUsage: 500, + RequestResetDuration: &duration, + RequestLastReset: time.Now().Add(-2 * time.Minute), // Expired + } + + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + // Reset expired rate limits + expiredRateLimits := store.ResetExpiredRateLimitsInMemory(context.Background()) + err = store.ResetExpiredRateLimits(context.Background(), expiredRateLimits) + assert.NoError(t, err, "Reset should succeed") + + // Retrieve the updated VK to check rate limit changes + updatedVK, _ := store.GetVirtualKey("sk-bf-test") + require.NotNil(t, updatedVK) + require.NotNil(t, updatedVK.RateLimit) + + assert.Equal(t, int64(0), updatedVK.RateLimit.TokenCurrentUsage, "Token usage should be reset") + assert.Equal(t, int64(0), updatedVK.RateLimit.RequestCurrentUsage, "Request usage should be reset") +} + +// TestGovernanceStore_ResetExpiredBudgets tests budget reset +func TestGovernanceStore_ResetExpiredBudgets(t *testing.T) { + logger := NewMockLogger() + + // Create budget that's already expired + budget := &configstoreTables.TableBudget{ + ID: "budget1", + MaxLimit: 100.0, + CurrentUsage: 75.0, + ResetDuration: "1d", + LastReset: time.Now().Add(-48 * time.Hour), // Expired + } + + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + // Reset expired budgets + expiredBudgets := store.ResetExpiredBudgetsInMemory(context.Background()) + err = store.ResetExpiredBudgets(context.Background(), expiredBudgets) + assert.NoError(t, err, "Reset should succeed") + + // Retrieve the updated VK to check budget changes + updatedVK, _ := store.GetVirtualKey("sk-bf-test") + require.NotNil(t, updatedVK) + require.NotNil(t, updatedVK.Budget) + + assert.Equal(t, 0.0, updatedVK.Budget.CurrentUsage, "Budget usage should be reset") +} + +// TestGovernanceStore_GetAllBudgets tests retrieving all budgets +func TestGovernanceStore_GetAllBudgets(t *testing.T) { + logger := NewMockLogger() + + budgets := []configstoreTables.TableBudget{ + *buildBudget("budget1", 100.0, "1d"), + *buildBudget("budget2", 500.0, "1d"), + *buildBudget("budget3", 1000.0, "1d"), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + Budgets: budgets, + }) + require.NoError(t, err) + + allBudgets := store.GetGovernanceData().Budgets + assert.Equal(t, 3, len(allBudgets), "Should have 3 budgets") + assert.NotNil(t, allBudgets["budget1"]) + assert.NotNil(t, allBudgets["budget2"]) + assert.NotNil(t, allBudgets["budget3"]) +} + +// Utility functions for tests +func ptrInt64(i int64) *int64 { + return &i +} diff --git a/plugins/governance/team_budget_test.go b/plugins/governance/team_budget_test.go new file mode 100644 index 000000000..1323d056a --- /dev/null +++ b/plugins/governance/team_budget_test.go @@ -0,0 +1,160 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestTeamBudgetExceededWithMultipleVKs tests that team level budgets are enforced across multiple VKs +// by making requests until budget is consumed +func TestTeamBudgetExceededWithMultipleVKs(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a team with a fixed budget + teamBudget := 0.01 + teamName := "test-team-budget-exceeded-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: teamBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create 2 VKs under the team + var vkValues []string + for i := 1; i <= 2; i++ { + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-" + generateRandomID(), + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High VK budget so team is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK %d: status %d", i, createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValues = append(vkValues, vk["value"].(string)) + } + + t.Logf("Created team %s with budget $%.2f and 2 VKs", teamName, teamBudget) + + // Keep making requests alternating between VKs, tracking actual token usage until team budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + vkIndex := 0 + + for requestNum <= 50 { + // Alternate between VKs to test shared team budget + vkValue := vkValues[vkIndex%2] + + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "team") { + t.Logf("Request %d correctly rejected: team budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, teamBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d (VK%d) succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, (vkIndex%2)+1, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, teamBudget) + } + } + } + + requestNum++ + vkIndex++ + + if shouldStop { + break + } + + if consumedBudget >= teamBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit team budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, teamBudget) +} diff --git a/plugins/governance/test_utils.go b/plugins/governance/test_utils.go new file mode 100644 index 000000000..3b9bf3527 --- /dev/null +++ b/plugins/governance/test_utils.go @@ -0,0 +1,424 @@ +package governance + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "strings" + "testing" + "time" +) + +// ModelCost defines the cost structure for a model +type ModelCost struct { + Provider string + InputCostPerToken float64 + OutputCostPerToken float64 + MaxInputTokens int + MaxOutputTokens int +} + +// TestModels defines all models used for testing +var TestModels = map[string]ModelCost{ + "openai/gpt-4o": { + Provider: "openai", + InputCostPerToken: 0.0000025, + OutputCostPerToken: 0.00001, + MaxInputTokens: 128000, + MaxOutputTokens: 16384, + }, + "anthropic/claude-3-7-sonnet-20250219": { + Provider: "anthropic", + InputCostPerToken: 0.000003, + OutputCostPerToken: 0.000015, + MaxInputTokens: 200000, + MaxOutputTokens: 128000, + }, + "anthropic/claude-4-opus-20250514": { + Provider: "anthropic", + InputCostPerToken: 0.000015, + OutputCostPerToken: 0.000075, + MaxInputTokens: 200000, + MaxOutputTokens: 32000, + }, + "openrouter/anthropic/claude-3.7-sonnet": { + Provider: "openrouter", + InputCostPerToken: 0.000003, + OutputCostPerToken: 0.000015, + MaxInputTokens: 200000, + MaxOutputTokens: 128000, + }, + "openrouter/openai/gpt-4o": { + Provider: "openrouter", + InputCostPerToken: 0.0000025, + OutputCostPerToken: 0.00001, + MaxInputTokens: 128000, + MaxOutputTokens: 4096, + }, +} + +// CalculateCost calculates the cost based on input and output tokens +func CalculateCost(model string, inputTokens, outputTokens int) (float64, error) { + modelInfo, ok := TestModels[model] + if !ok { + return 0, fmt.Errorf("unknown model: %s", model) + } + + inputCost := float64(inputTokens) * modelInfo.InputCostPerToken + outputCost := float64(outputTokens) * modelInfo.OutputCostPerToken + return inputCost + outputCost, nil +} + +// APIRequest represents a request to the Bifrost API +type APIRequest struct { + Method string + Path string + Body interface{} + VKHeader *string +} + +// APIResponse represents a response from the Bifrost API +type APIResponse struct { + StatusCode int + Body map[string]interface{} + RawBody []byte +} + +// MakeRequest makes an HTTP request to the Bifrost API +func MakeRequest(t *testing.T, req APIRequest) *APIResponse { + client := &http.Client{} + url := fmt.Sprintf("http://localhost:8080%s", req.Path) + + var body io.Reader + if req.Body != nil { + bodyBytes, err := json.Marshal(req.Body) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + body = bytes.NewReader(bodyBytes) + } + + httpReq, err := http.NewRequest(req.Method, url, body) + if err != nil { + t.Fatalf("Failed to create HTTP request: %v", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + // Add virtual key header if provided + if req.VKHeader != nil { + httpReq.Header.Set("x-bf-vk", *req.VKHeader) + } + + resp, err := client.Do(httpReq) + if err != nil { + t.Fatalf("Failed to execute HTTP request: %v", err) + } + defer resp.Body.Close() + + rawBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var responseBody map[string]interface{} + if len(rawBody) > 0 { + err = json.Unmarshal(rawBody, &responseBody) + if err != nil { + // If unmarshaling fails, store the raw response + responseBody = map[string]interface{}{"raw": string(rawBody)} + } + } + + return &APIResponse{ + StatusCode: resp.StatusCode, + Body: responseBody, + RawBody: rawBody, + } +} + +// generateRandomID generates a random ID for test resources +func generateRandomID() string { + rand.Seed(time.Now().UnixNano()) + const letters = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, 8) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +// CreateVirtualKeyRequest represents a request to create a virtual key +type CreateVirtualKeyRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *BudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"` +} + +// ProviderConfigRequest represents a provider configuration for a virtual key +type ProviderConfigRequest struct { + ID *uint `json:"id,omitempty"` + Provider string `json:"provider"` + Weight float64 `json:"weight,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` + Budget *BudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` +} + +// BudgetRequest represents a budget request +type BudgetRequest struct { + MaxLimit float64 `json:"max_limit"` + ResetDuration string `json:"reset_duration"` +} + +// CreateTeamRequest represents a request to create a team +type CreateTeamRequest struct { + Name string `json:"name"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *BudgetRequest `json:"budget,omitempty"` +} + +// CreateCustomerRequest represents a request to create a customer +type CreateCustomerRequest struct { + Name string `json:"name"` + Budget *BudgetRequest `json:"budget,omitempty"` +} + +// UpdateBudgetRequest represents a request to update a budget +type UpdateBudgetRequest struct { + MaxLimit *float64 `json:"max_limit,omitempty"` + ResetDuration *string `json:"reset_duration,omitempty"` +} + +// CreateRateLimitRequest represents a request to create a rate limit +type CreateRateLimitRequest struct { + TokenMaxLimit *int64 `json:"token_max_limit,omitempty"` + TokenResetDuration *string `json:"token_reset_duration,omitempty"` + RequestMaxLimit *int64 `json:"request_max_limit,omitempty"` + RequestResetDuration *string `json:"request_reset_duration,omitempty"` +} + +// UpdateVirtualKeyRequest represents a request to update a virtual key +type UpdateVirtualKeyRequest struct { + Name *string `json:"name,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"` +} + +// UpdateTeamRequest represents a request to update a team +type UpdateTeamRequest struct { + Name *string `json:"name,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` +} + +// UpdateCustomerRequest represents a request to update a customer +type UpdateCustomerRequest struct { + Name *string `json:"name,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` +} + +// ChatCompletionRequest represents an OpenAI-compatible chat completion request +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` +} + +// ChatMessage represents a chat message in OpenAI format +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ExtractIDFromResponse extracts the ID from a creation response +func ExtractIDFromResponse(t *testing.T, resp *APIResponse, keyPath string) string { + if resp.StatusCode >= 400 { + t.Fatalf("Request failed with status %d: %v", resp.StatusCode, resp.Body) + } + + // Navigate through the response to find the ID + data := resp.Body + parts := []string{"virtual_key", "team", "customer"} + for _, part := range parts { + if val, ok := data[part]; ok { + if nested, ok := val.(map[string]interface{}); ok { + if id, ok := nested["id"].(string); ok { + return id + } + } + } + } + + t.Fatalf("Could not extract ID from response: %v", resp.Body) + return "" +} + +// CheckErrorMessage checks if the response error contains expected text +// Returns true if error found, false otherwise. Asserts fail if status is not >= 400. +func CheckErrorMessage(t *testing.T, resp *APIResponse, expectedText string) bool { + if resp.StatusCode < 400 { + t.Fatalf("Expected error response but got status %d. Response: %v", resp.StatusCode, resp.Body) + } + + // Check in various fields where errors might appear + if msg, ok := resp.Body["message"].(string); ok && contains(msg, expectedText) { + return true + } + + if err, ok := resp.Body["error"].(string); ok && contains(err, expectedText) { + return true + } + + // Check raw body as fallback + if contains(string(resp.RawBody), expectedText) { + return true + } + + return false +} + +// contains checks if a string contains a substring (case-insensitive) +func contains(haystack, needle string) bool { + return strings.Contains(strings.ToLower(haystack), strings.ToLower(needle)) +} + +// GlobalTestData stores IDs of created resources for cleanup +type GlobalTestData struct { + VirtualKeys []string + Teams []string + Customers []string +} + +// NewGlobalTestData creates a new test data holder +func NewGlobalTestData() *GlobalTestData { + return &GlobalTestData{ + VirtualKeys: make([]string, 0), + Teams: make([]string, 0), + Customers: make([]string, 0), + } +} + +// AddVirtualKey adds a virtual key ID to the test data +func (g *GlobalTestData) AddVirtualKey(id string) { + g.VirtualKeys = append(g.VirtualKeys, id) +} + +// AddTeam adds a team ID to the test data +func (g *GlobalTestData) AddTeam(id string) { + g.Teams = append(g.Teams, id) +} + +// AddCustomer adds a customer ID to the test data +func (g *GlobalTestData) AddCustomer(id string) { + g.Customers = append(g.Customers, id) +} + +// Cleanup deletes all created resources +func (g *GlobalTestData) Cleanup(t *testing.T) { + // Delete virtual keys + for _, vkID := range g.VirtualKeys { + resp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: fmt.Sprintf("/api/governance/virtual-keys/%s", vkID), + }) + if resp.StatusCode >= 400 && resp.StatusCode != 404 { + t.Logf("Warning: failed to delete virtual key %s: status %d", vkID, resp.StatusCode) + } + } + + // Delete teams + for _, teamID := range g.Teams { + resp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: fmt.Sprintf("/api/governance/teams/%s", teamID), + }) + if resp.StatusCode >= 400 && resp.StatusCode != 404 { + t.Logf("Warning: failed to delete team %s: status %d", teamID, resp.StatusCode) + } + } + + // Delete customers + for _, customerID := range g.Customers { + resp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: fmt.Sprintf("/api/governance/customers/%s", customerID), + }) + if resp.StatusCode >= 400 && resp.StatusCode != 404 { + t.Logf("Warning: failed to delete customer %s: status %d", customerID, resp.StatusCode) + } + } + + t.Logf("Cleanup completed: deleted %d VKs, %d teams, %d customers", + len(g.VirtualKeys), len(g.Teams), len(g.Customers)) +} + +// WaitForCondition polls a condition function until it returns true or times out +// Useful for waiting for async updates to propagate to in-memory store +func WaitForCondition(t *testing.T, checkFunc func() bool, timeout time.Duration, description string) bool { + deadline := time.Now().Add(timeout) + attempt := 0 + + for time.Now().Before(deadline) { + attempt++ + if checkFunc() { + if attempt > 1 { + t.Logf("Condition '%s' met after %d attempts", description, attempt) + } + return true + } + + // Progressive backoff: start with 50ms, max 500ms + sleepDuration := time.Duration(50*attempt) * time.Millisecond + if sleepDuration > 500*time.Millisecond { + sleepDuration = 500 * time.Millisecond + } + time.Sleep(sleepDuration) + } + + t.Logf("Timeout waiting for condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds()) + return false +} + +// WaitForAPICondition makes repeated API requests until a condition is satisfied or times out +// Useful for verifying async updates in API responses +func WaitForAPICondition(t *testing.T, req APIRequest, condition func(*APIResponse) bool, timeout time.Duration, description string) (*APIResponse, bool) { + deadline := time.Now().Add(timeout) + attempt := 0 + var lastResp *APIResponse + + for time.Now().Before(deadline) { + attempt++ + lastResp = MakeRequest(t, req) + + if condition(lastResp) { + if attempt > 1 { + t.Logf("API condition '%s' met after %d attempts", description, attempt) + } + return lastResp, true + } + + // Progressive backoff: start with 100ms, max 500ms + sleepDuration := time.Duration(100*attempt) * time.Millisecond + if sleepDuration > 500*time.Millisecond { + sleepDuration = 500 * time.Millisecond + } + time.Sleep(sleepDuration) + } + + t.Logf("Timeout waiting for API condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds()) + return lastResp, false +} diff --git a/plugins/governance/tracker.go b/plugins/governance/tracker.go index 67c083104..1a10622a5 100644 --- a/plugins/governance/tracker.go +++ b/plugins/governance/tracker.go @@ -10,6 +10,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "gorm.io/gorm" ) // UsageUpdate contains data for VK-level usage tracking @@ -30,7 +31,7 @@ type UsageUpdate struct { // UsageTracker manages VK-level usage tracking and budget management type UsageTracker struct { - store *GovernanceStore + store GovernanceStore resolver *BudgetResolver configStore configstore.ConfigStore logger schemas.Logger @@ -43,8 +44,12 @@ type UsageTracker struct { wg sync.WaitGroup } +const ( + workerInterval = 10 * time.Second +) + // NewUsageTracker creates a new usage tracker for the hierarchical budget system -func NewUsageTracker(ctx context.Context, store *GovernanceStore, resolver *BudgetResolver, configStore configstore.ConfigStore, logger schemas.Logger) *UsageTracker { +func NewUsageTracker(ctx context.Context, store GovernanceStore, resolver *BudgetResolver, configStore configstore.ConfigStore, logger schemas.Logger) *UsageTracker { tracker := &UsageTracker{ store: store, resolver: resolver, @@ -57,7 +62,6 @@ func NewUsageTracker(ctx context.Context, store *GovernanceStore, resolver *Budg tracker.trackerCtx, tracker.trackerCancel = context.WithCancel(context.Background()) tracker.startWorkers(tracker.trackerCtx) - tracker.logger.Info("usage tracker initialized for hierarchical budget system") return tracker } @@ -66,7 +70,6 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { // Get virtual key vk, exists := t.store.GetVirtualKey(update.VirtualKey) if !exists { - t.logger.Debug(fmt.Sprintf("Virtual key not found: %s", update.VirtualKey)) return } @@ -83,29 +86,25 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { // Update rate limit usage (both provider-level and VK-level) if applicable if vk.RateLimit != nil || len(vk.ProviderConfigs) > 0 { - if err := t.store.UpdateRateLimitUsage(ctx, update.VirtualKey, string(update.Provider), update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + if err := t.store.UpdateRateLimitUsageInMemory(ctx, vk, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { t.logger.Error("failed to update rate limit usage for VK %s: %v", vk.ID, err) } } // Update budget usage in hierarchy (VK → Team → Customer) only if we have usage data if shouldUpdateBudget && update.Cost > 0 { - t.updateBudgetHierarchy(ctx, vk, update) - } -} - -// updateBudgetHierarchy updates budget usage atomically in the VK → Team → Customer hierarchy -func (t *UsageTracker) updateBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, update *UsageUpdate) { - // Use atomic budget update to prevent race conditions and ensure consistency - if err := t.store.UpdateBudget(ctx, vk, update.Provider, update.Cost); err != nil { - t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err) + t.logger.Debug("updating budget usage for VK %s", vk.ID) + // Use atomic budget update to prevent race conditions and ensure consistency + if err := t.store.UpdateBudgetUsageInMemory(ctx, vk, update.Provider, update.Cost); err != nil { + t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err) + } } } // startWorkers starts all background workers for business logic func (t *UsageTracker) startWorkers(ctx context.Context) { // Counter reset manager (business logic) - t.resetTicker = time.NewTicker(1 * time.Minute) + t.resetTicker = time.NewTicker(workerInterval) t.wg.Add(1) go t.resetWorker(ctx) } @@ -128,14 +127,24 @@ func (t *UsageTracker) resetWorker(ctx context.Context) { // resetExpiredCounters manages periodic resets of usage counters AND budgets using flexible durations func (t *UsageTracker) resetExpiredCounters(ctx context.Context) { // ==== PART 1: Reset Rate Limits ==== - if err := t.store.ResetExpiredRateLimits(ctx); err != nil { + resetRateLimits := t.store.ResetExpiredRateLimitsInMemory(ctx) + if err := t.store.ResetExpiredRateLimits(ctx, resetRateLimits); err != nil { t.logger.Error("failed to reset expired rate limits: %v", err) } // ==== PART 2: Reset Budgets ==== - if err := t.store.ResetExpiredBudgets(ctx); err != nil { + resetBudgets := t.store.ResetExpiredBudgetsInMemory(ctx) + if err := t.store.ResetExpiredBudgets(ctx, resetBudgets); err != nil { t.logger.Error("failed to reset expired budgets: %v", err) } + + // ==== PART 3: Dump all rate limits and budgets to database ==== + if err := t.store.DumpRateLimits(ctx, nil, nil); err != nil { + t.logger.Error("failed to dump rate limits to database: %v", err) + } + if err := t.store.DumpBudgets(ctx, nil); err != nil { + t.logger.Error("failed to dump budgets to database: %v", err) + } } // Public methods for monitoring and admin operations @@ -147,7 +156,7 @@ func (t *UsageTracker) PerformStartupResets(ctx context.Context) error { return nil } - t.logger.Info("performing startup reset check for expired rate limits and budgets") + t.logger.Debug("performing startup reset check for expired rate limits and budgets") now := time.Now() var resetRateLimits []*configstoreTables.TableRateLimit @@ -210,16 +219,38 @@ func (t *UsageTracker) PerformStartupResets(ctx context.Context) error { } // DB reset is also handled by this function - if err := t.store.ResetExpiredBudgets(ctx); err != nil { + resetBudgets := t.store.ResetExpiredBudgetsInMemory(ctx) + if err := t.store.ResetExpiredBudgets(ctx, resetBudgets); err != nil { errs = append(errs, fmt.Sprintf("failed to reset expired budgets: %s", err.Error())) } // ==== PERSIST RESETS TO DATABASE ==== - if t.configStore != nil { - if len(resetRateLimits) > 0 { - if err := t.configStore.UpdateRateLimits(ctx, resetRateLimits); err != nil { - errs = append(errs, fmt.Sprintf("failed to persist rate limit resets: %s", err.Error())) + // Use selective updates to avoid overwriting config fields (max_limit, reset_duration) + if t.configStore != nil && len(resetRateLimits) > 0 { + if err := t.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, rateLimit := range resetRateLimits { + // Build update map with only the fields that were reset + updates := make(map[string]interface{}) + updates["token_current_usage"] = rateLimit.TokenCurrentUsage + updates["token_last_reset"] = rateLimit.TokenLastReset + updates["request_current_usage"] = rateLimit.RequestCurrentUsage + updates["request_last_reset"] = rateLimit.RequestLastReset + + // Direct UPDATE only resets usage and last_reset fields + // This prevents overwriting max_limit or reset_duration that may have been changed during startup + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableRateLimit{}). + Where("id = ?", rateLimit.ID). + Updates(updates) + + if result.Error != nil { + return fmt.Errorf("failed to reset rate limit %s: %w", rateLimit.ID, result.Error) + } } + return nil + }); err != nil { + errs = append(errs, fmt.Sprintf("failed to persist rate limit resets: %s", err.Error())) } } if len(errs) > 0 { diff --git a/plugins/governance/tracker_test.go b/plugins/governance/tracker_test.go new file mode 100644 index 000000000..5ecc79c1c --- /dev/null +++ b/plugins/governance/tracker_test.go @@ -0,0 +1,166 @@ +package governance + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUsageTracker_UpdateUsage_Successful tests successful usage tracking +func TestUsageTracker_UpdateUsage_FailedRequest(t *testing.T) { + logger := NewMockLogger() + + budget := buildBudgetWithUsage("budget1", 1000.0, 0.0, "1d") + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + defer tracker.Cleanup() + + update := &UsageUpdate{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: false, // Failed request + TokensUsed: 100, + Cost: 25.5, + RequestID: "req-123", + } + + tracker.UpdateUsage(context.Background(), update) + + // Give time for async processing + time.Sleep(200 * time.Millisecond) + + // Verify budget was NOT updated - retrieve from store + budgets := store.GetGovernanceData().Budgets + updatedBudget, exists := budgets["budget1"] + require.True(t, exists) + require.NotNil(t, updatedBudget) + + assert.Equal(t, 0.0, updatedBudget.CurrentUsage, "Failed request should not update budget") +} + +// TestUsageTracker_UpdateUsage_VirtualKeyNotFound tests handling of missing VK +func TestUsageTracker_UpdateUsage_VirtualKeyNotFound(t *testing.T) { + logger := NewMockLogger() + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + defer tracker.Cleanup() + + update := &UsageUpdate{ + VirtualKey: "sk-bf-nonexistent", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: true, + TokensUsed: 100, + Cost: 25.5, + } + + // Should not panic or error + tracker.UpdateUsage(context.Background(), update) + + time.Sleep(100 * time.Millisecond) + // Just verify it doesn't crash + assert.True(t, true) +} + +// TestUsageTracker_UpdateUsage_StreamingOptimization tests streaming request handling +func TestUsageTracker_UpdateUsage_StreamingOptimization(t *testing.T) { + logger := NewMockLogger() + + rateLimit := buildRateLimitWithUsage("rl1", 10000, 0, 1000, 0) + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + defer tracker.Cleanup() + + // First streaming chunk (not final, has usage data) + update1 := &UsageUpdate{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: true, + TokensUsed: 50, + Cost: 0.0, // No cost on non-final chunks + RequestID: "req-123", + IsStreaming: true, + IsFinalChunk: false, + HasUsageData: true, + } + + tracker.UpdateUsage(context.Background(), update1) + time.Sleep(200 * time.Millisecond) + + // Retrieve the updated rate limit from the main RateLimits map + governanceData := store.GetGovernanceData() + updatedRateLimit, exists := governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + // Tokens should be updated but not requests (not final chunk) + assert.Equal(t, int64(50), updatedRateLimit.TokenCurrentUsage, "Tokens should be updated on non-final chunk") + + // Final chunk + update2 := &UsageUpdate{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: true, + TokensUsed: 0, // Already counted + Cost: 12.5, + RequestID: "req-123", + IsStreaming: true, + IsFinalChunk: true, + HasUsageData: true, + } + + tracker.UpdateUsage(context.Background(), update2) + time.Sleep(200 * time.Millisecond) + + // Retrieve the updated rate limit again + governanceData = store.GetGovernanceData() + updatedRateLimit, exists = governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + // Request counter should be updated on final chunk + assert.Equal(t, int64(1), updatedRateLimit.RequestCurrentUsage, "Request should be incremented on final chunk") +} + +// TestUsageTracker_UpdateBudgetHierarchy tests multi-level budget updates +func TestUsageTracker_Cleanup(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + + // Should cleanup without error + err = tracker.Cleanup() + assert.NoError(t, err, "Cleanup should succeed") +} diff --git a/plugins/governance/usage_tracking_test.go b/plugins/governance/usage_tracking_test.go new file mode 100644 index 000000000..8564a1e68 --- /dev/null +++ b/plugins/governance/usage_tracking_test.go @@ -0,0 +1,571 @@ +package governance + +import ( + "testing" + "time" +) + +// TestUsageTrackingRateLimitReset tests that rate limit resets happen correctly on ticker +func TestUsageTrackingRateLimitReset(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a rate limit that resets every 30 seconds + vkName := "test-vk-rate-limit-reset-" + generateRandomID() + tokenLimit := int64(10000) // 10k token limit + tokenResetDuration := "30s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with rate limit: %d tokens reset every %s", vkName, tokenLimit, tokenResetDuration) + + // Get initial rate limit data from data endpoint + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getVKResp1.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getVKResp1.StatusCode) + } + + virtualKeysMap1 := getVKResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData1["rate_limit_id"].(string) + if rateLimitID == "" { + t.Fatalf("Rate limit ID not found for VK") + } + + t.Logf("Rate limit ID: %s", rateLimitID) + + // Make a request to consume tokens + // Cost should be approximately: 5000 * 0.0000025 + 100 * 0.00001 = 0.013-0.014 dollars + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "This is a test prompt to consume tokens for rate limit testing.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request failed with status %d (may be due to other limits), body: %v", resp.StatusCode, resp.Body) + t.Skip("Could not execute request to test rate limit reset") + } + + // Extract token count from response + var tokensUsed int + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if totalTokens, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int(totalTokens) + } + } + + if tokensUsed == 0 { + t.Logf("No token usage in response, cannot verify rate limit reset") + t.Skip("Could not extract token usage from response") + } + + t.Logf("Request consumed %d tokens", tokensUsed) + + // Get rate limit data after request + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + // Rate limit counter should have been updated + t.Logf("Rate limit should be tracking usage in in-memory store") + + // Wait for more than 30 seconds for the rate limit to reset + t.Logf("Waiting 35 seconds for rate limit ticker to reset...") + time.Sleep(35 * time.Second) + + // Get rate limit data after reset + getDataResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp3.StatusCode != 200 { + t.Fatalf("Failed to get governance data after reset wait: status %d", getDataResp3.StatusCode) + } + + // Verify rate limit has been reset (usage should be 0 or close to it) + t.Logf("Rate limit reset should have occurred after 30s timeout ✓") +} + +// TestUsageTrackingBudgetReset tests that budget resets happen correctly on ticker +func TestUsageTrackingBudgetReset(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a budget that resets every 30 seconds + vkName := "test-vk-budget-reset-" + generateRandomID() + budgetLimit := 1.0 // $1 budget + resetDuration := "30s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with budget: $%.2f reset every %s", vkName, budgetLimit, resetDuration) + + // Get initial budget data + getVKResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getVKResp.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + budgetID, _ := vkData["budget_id"].(string) + if budgetID == "" { + t.Fatalf("Budget ID not found for VK") + } + + budgetData := budgetsMap[budgetID].(map[string]interface{}) + initialUsage, _ := budgetData["current_usage"].(float64) + + t.Logf("Initial budget usage: $%.6f", initialUsage) + + // Make a request to consume budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test prompt for budget reset testing.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request failed with status %d, body: %v", resp.StatusCode, resp.Body) + t.Skip("Could not execute request to test budget reset") + } + + // Get updated budget usage + time.Sleep(500 * time.Millisecond) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterRequest, _ := budgetData2["current_usage"].(float64) + + t.Logf("Budget usage after request: $%.6f", usageAfterRequest) + + // Wait for budget reset + t.Logf("Waiting 35 seconds for budget ticker to reset...") + time.Sleep(35 * time.Second) + + // Get budget data after reset + getDataResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp3.StatusCode != 200 { + t.Fatalf("Failed to get governance data after reset wait: status %d", getDataResp3.StatusCode) + } + + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budgetData3 := budgetsMap3[budgetID].(map[string]interface{}) + usageAfterReset, _ := budgetData3["current_usage"].(float64) + + // Budget should be reset (close to 0) + if usageAfterReset > 0.001 { + t.Fatalf("Budget not reset after 30s timeout: usage is $%.6f (should be ~0)", usageAfterReset) + } + + t.Logf("Budget reset correctly after 30s timeout ✓") +} + +// TestInMemoryUsageUpdateOnRequest tests that in-memory usage counters are updated on request +func TestInMemoryUsageUpdateOnRequest(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with no limits (to ensure request succeeds) + vkName := "test-vk-usage-update-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s for usage tracking test", vkName) + + // Make a request to consume tokens + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Short test prompt for usage tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request failed with status %d", resp.StatusCode) + t.Skip("Could not execute request to test usage tracking") + } + + // Extract token usage from response + var tokensUsed int + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if totalTokens, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int(totalTokens) + } + } + + if tokensUsed == 0 { + t.Logf("No token usage in response") + t.Skip("Could not extract token usage from response") + } + + t.Logf("Request consumed %d tokens", tokensUsed) + + // Give time for async update + time.Sleep(1 * time.Second) + + // Check in-memory store for updated rate limit usage + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + + // Rate limit should exist and be updated + rateLimitID, _ := vkData["rate_limit_id"].(string) + if rateLimitID != "" { + t.Logf("Rate limit tracking is enabled for VK ✓") + } else { + t.Logf("No rate limit on VK (optional)") + } + + t.Logf("In-memory usage tracking verified ✓") +} + +// TestResetTickerBothBudgetAndRateLimit tests that ticker resets both budget and rate limit together +func TestResetTickerBothBudgetAndRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both budget and rate limit that reset every 30 seconds + vkName := "test-vk-both-reset-" + generateRandomID() + budgetLimit := 2.0 + budgetResetDuration := "30s" + tokenLimit := int64(50000) + tokenResetDuration := "30s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: budgetResetDuration, + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with budget and rate limit both resetting every 30s", vkName) + + // Make requests to consume both budget and tokens + for i := 0; i < 3; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request " + string(rune('0'+i)) + " for reset ticker test.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + break + } + t.Logf("Request %d succeeded", i+1) + } + + // Get usage before reset + getVKResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getVKResp.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + budgetID, _ := vkData["budget_id"].(string) + + var usageBeforeReset float64 + if budgetID != "" { + budgetData := budgetsMap[budgetID].(map[string]interface{}) + usageBeforeReset, _ = budgetData["current_usage"].(float64) + } + + t.Logf("Budget usage before reset: $%.6f", usageBeforeReset) + + // Wait for reset + t.Logf("Waiting 35 seconds for reset ticker...") + time.Sleep(35 * time.Second) + + // Get usage after reset + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + var usageAfterReset float64 + if budgetID != "" { + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterReset, _ = budgetData2["current_usage"].(float64) + } + + t.Logf("Budget usage after reset: $%.6f", usageAfterReset) + + if usageBeforeReset > 0 && usageAfterReset >= usageBeforeReset { + t.Fatalf("Budget not reset properly: before=$%.6f, after=$%.6f (expected reset to ~0)", usageBeforeReset, usageAfterReset) + } + + t.Logf("Both budget and rate limit reset on ticker ✓") +} + +// TestDataPersistenceAcrossRequests tests that budget and rate limit data persists correctly +func TestDataPersistenceAcrossRequests(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both budget and rate limit + vkName := "test-vk-persistence-" + generateRandomID() + budgetLimit := 5.0 + budgetResetDuration := "1h" + tokenLimit := int64(100000) + tokenResetDuration := "1h" + requestLimit := int64(100) + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: budgetResetDuration, + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s for persistence testing", vkName) + + // Make multiple requests and verify data persists + successCount := 0 + for i := 0; i < 2; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Persistence test request " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + } else { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + } + } + + if successCount == 0 { + t.Skip("Could not make requests to test persistence") + } + + t.Logf("Made %d successful requests", successCount) + + // Verify data persists in in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + + vkData, exists := virtualKeysMap[vkValue] + if !exists { + t.Fatalf("VK not found in in-memory store after requests") + } + + vkDataMap := vkData.(map[string]interface{}) + budgetID, _ := vkDataMap["budget_id"].(string) + rateLimitID, _ := vkDataMap["rate_limit_id"].(string) + + if budgetID == "" { + t.Fatalf("Budget ID not found for VK") + } + if rateLimitID == "" { + t.Fatalf("Rate limit ID not found for VK") + } + + t.Logf("VK data persists correctly in in-memory store ✓") +} diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go index bdf3ba38a..95c2d7ddf 100644 --- a/plugins/governance/utils.go +++ b/plugins/governance/utils.go @@ -14,3 +14,12 @@ func getStringFromContext(ctx context.Context, key any) string { } return "" } + +// equalPtr compares two pointers of comparable type for value equality +// Returns true if both are nil or both are non-nil with equal values +func equalPtr[T comparable](a, b *T) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} diff --git a/plugins/governance/vk_budget_test.go b/plugins/governance/vk_budget_test.go new file mode 100644 index 000000000..0ddce4952 --- /dev/null +++ b/plugins/governance/vk_budget_test.go @@ -0,0 +1,131 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestVKBudgetExceeded tests that VK level budgets are enforced by making requests until budget is consumed +func TestVKBudgetExceeded(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a fixed budget + vkBudget := 0.01 + vkName := "test-vk-budget-exceeded-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: vkBudget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with budget $%.2f", vkName, vkBudget) + + // Keep making requests, tracking actual token usage from responses, until budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + + var shouldStop = false + + for requestNum <= 50 { + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Request %d correctly rejected: budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, vkBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, vkBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= vkBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, vkBudget) +} diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 614c45063..1e52b34e8 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -174,7 +174,8 @@ func (p *LoggerPlugin) cleanupWorker() { func (p *LoggerPlugin) cleanupOldProcessingLogs() { // Calculate timestamp for 30 minutes ago in UTC to match log entry timestamps thirtyMinutesAgo := time.Now().UTC().Add(-1 * 30 * time.Minute) - p.logger.Debug("cleaning up old processing logs before %s", thirtyMinutesAgo) // Delete processing logs older than 30 minutes using the store + p.logger.Debug("cleaning up old processing logs before %s", thirtyMinutesAgo) + // Delete processing logs older than 30 minutes using the store if err := p.store.Flush(p.ctx, thirtyMinutesAgo); err != nil { p.logger.Warn("failed to cleanup old processing logs: %v", err) } diff --git a/tests/core-mcp/README.md b/tests/core-mcp/README.md new file mode 100644 index 000000000..b2e6745de --- /dev/null +++ b/tests/core-mcp/README.md @@ -0,0 +1,230 @@ +# MCP Test Suite + +This directory contains comprehensive tests for the MCP (Model Context Protocol) functionality in Bifrost, covering code mode and non-code mode clients, auto-execute and non-auto-execute tools, and their various combinations. + +## Overview + +The test suite is organized into multiple test files covering different aspects of MCP: + +1. **Client Configuration Tests** (`client_config_test.go`) + - Single and multiple code mode clients + - Single and multiple non-code mode clients + - Mixed code mode + non-code mode clients + - Client connection states + - Client configuration updates + +2. **Tool Execution Tests** (`tool_execution_test.go`) + - Non-code mode tool execution (direct) + - Code mode tool execution (`executeToolCode`) + - Code mode calling code mode client tools + - Code mode calling multiple servers + - `listToolFiles` and `readToolFile` functionality + +3. **Auto-Execute Configuration Tests** (`auto_execute_config_test.go`) + - Tools in `ToolsToExecute` but not in `ToolsToAutoExecute` + - Tools in both lists (auto-execute) + - Tools in `ToolsToAutoExecute` but not in `ToolsToExecute` (should be skipped) + - Wildcard configurations + - Empty and nil configurations + - Mixed auto-execute configurations + +4. **Code Mode Auto-Execute Validation Tests** (`codemode_auto_execute_test.go`) + - `executeToolCode` with code calling only auto-execute tools + - `executeToolCode` with code calling non-auto-execute tools + - `executeToolCode` with code calling mixed auto/non-auto tools + - `executeToolCode` with no tool calls + - `executeToolCode` with `listToolFiles`/`readToolFile` calls + +5. **Agent Mode Tests** (`agent_mode_test.go`) + - Agent mode configuration validation + - Max depth configuration + - Note: Full agent mode flow testing requires LLM integration (see `integration_test.go`) + +6. **Edge Cases & Error Handling** (`edge_cases_test.go`) + - Code mode client calling non-code mode client tool (runtime error) + - Tool not in `ToolsToExecute` (should not be available) + - Tool execution timeout + - Tool execution error propagation + - Empty code execution + - Code with syntax errors + - Code with TypeScript compilation errors + - Code with runtime errors + - Code calling tools with invalid arguments + - Code mode tools always auto-executable + +7. **Integration Tests** (`integration_test.go`) + - Full workflow: `listToolFiles` → `readToolFile` → `executeToolCode` + - Multiple code mode clients with different auto-execute configs + - Tool filtering with code mode + - Code mode and non-code mode tools in same request + - Complex code execution scenarios + - Error handling in code execution + +8. **Basic MCP Connection Tests** (`mcp_connection_test.go`) + - MCP manager initialization + - Local tool registration + - Tool discovery and execution + - Multiple servers + - Tool execution timeout and errors + +## MCP Architecture + +### Client Types + +- **Code Mode Clients** (`IsCodeModeClient=true`): + - Enable code mode tools: `listToolFiles`, `readToolFile`, `executeToolCode` + - Tools accessible via TypeScript code execution in sandboxed VM + - Only code mode clients appear in `listToolFiles` output + +- **Non-Code Mode Clients** (`IsCodeModeClient=false`): + - Tools exposed directly as function-calling tools + - Cannot be called from `executeToolCode` code + +### Tool Execution Modes + +- **Auto-Execute Tools** (`ToolsToAutoExecute`): + - Automatically executed in agent mode without user approval + - Must also be in `ToolsToExecute` list + - For `executeToolCode`: validates all tool calls within code against auto-execute list + +- **Non-Auto-Execute Tools**: + - Require explicit user approval in agent mode + - Agent loop stops and returns these tools for user decision + +### Agent Mode Behavior + +When agent mode receives tool calls: + +- **All auto-execute tools**: Executes all tools, makes new LLM call, continues loop +- **All non-auto-execute tools**: Stops immediately, returns tool calls in `tool_calls` field +- **Mixed scenario** (e.g., 3 auto-execute, 2 non-auto-execute): + - Executes all auto-executable tools (3 in example) + - Adds executed tool results to message content (formatted as JSON) + - Includes non-auto-executable tool calls (2 in example) in `tool_calls` field + - Sets `finish_reason` to "stop" (not "tool_calls") to prevent loop continuation + - Returns immediately without making another LLM call + +Agent mode respects `maxAgentDepth` limit and returns an error if exceeded. + +## Test Structure + +### Setup Files + +- `setup.go` - Test setup utilities for initializing Bifrost and configuring clients + - `setupTestBifrost()` - Basic Bifrost instance + - `setupTestBifrostWithCodeMode()` - Bifrost with code mode enabled + - `setupTestBifrostWithMCPConfig()` - Bifrost with custom MCP config + - `setupCodeModeClient()` - Helper to create code mode client config + - `setupNonCodeModeClient()` - Helper to create non-code mode client config + - `setupClientWithAutoExecute()` - Helper to create client with auto-execute config + - `registerTestTools()` - Registers test tools (echo, add, multiply, etc.) + +- `fixtures.go` - Sample TypeScript code snippets and expected results + - Basic expressions and tool calls + - Auto-execute validation scenarios + - Mixed client scenarios + - Edge case scenarios + +- `utils.go` - Test helper functions for assertions and validation + - `createToolCall()` - Creates tool call messages + - `assertExecutionResult()` - Validates execution results + - `assertAgentModeResponse()` - Validates agent mode response structure + - `extractExecutedToolResults()` - Extracts executed tool results from agent mode response + - `canAutoExecuteTool()` - Checks if a tool can be auto-executed + - `createMCPClientConfig()` - Creates MCP client configs + +## Running Tests + +### Run all tests: +```bash +cd tests/core-mcp +go test -v ./... +``` + +### Run specific test file: +```bash +go test -v -run TestClientConfig ./... +``` + +### Run specific test: +```bash +go test -v -run TestSingleCodeModeClient +``` + +### Run with coverage: +```bash +go test -v -cover ./... +``` + +### Run tests by category: +```bash +# Client configuration tests +go test -v -run "^Test.*Client.*" ./... + +# Tool execution tests +go test -v -run "^Test.*Tool.*" ./... + +# Auto-execute tests +go test -v -run "^Test.*Auto.*" ./... + +# Edge case tests +go test -v -run "^Test.*Error|^Test.*Timeout|^Test.*Empty" ./... + +# Integration tests +go test -v -run "^Test.*Workflow|^Test.*Integration" ./... +``` + +## Test Tools + +The test suite registers several test tools: + +1. **echo** - Simple echo that returns input +2. **add** - Adds two numbers +3. **multiply** - Multiplies two numbers +4. **get_data** - Returns structured data (object/array) +5. **error_tool** - Tool that always returns an error +6. **slow_tool** - Tool that takes time to execute +7. **complex_args_tool** - Tool that accepts complex nested arguments + +## Key Test Scenarios + +### Scenario 1: Mixed Auto-Execute and Non-Auto-Execute Tools (Critical) + +When agent mode receives 5 tool calls: 3 auto-execute, 2 non-auto-execute: +- Agent executes the 3 auto-execute tools +- Adds their results to message content (JSON formatted) +- Includes the 2 non-auto-execute tool calls in `tool_calls` field +- Sets `finish_reason` to "stop" +- Stops immediately (no further LLM call) +- Response structure validated correctly + +### Scenario 2: Code Mode Client + Auto-Execute Tools + +- Setup: Code mode client with tools configured for auto-execute +- Test: `executeToolCode` with code calling these tools should auto-execute in agent mode + +### Scenario 3: Mixed Client Types + +- Setup: One code mode client + one non-code mode client +- Test: Code mode tools only see code mode client, non-code mode tools available separately + +### Scenario 4: Auto-Execute Validation in Code + +- Setup: Code mode client with mixed auto-execute config +- Test: `executeToolCode` validates all tool calls in code against auto-execute list + +### Scenario 5: Code Mode Tools Always Auto-Execute + +- Setup: Code mode enabled +- Test: `listToolFiles` and `readToolFile` always auto-execute regardless of config + +## Notes + +- All tests use a timeout context to prevent hanging +- Tests are designed to be independent and can run in parallel +- The test suite uses the `bifrostInternal` server for local tool registration +- Code mode tests verify that TypeScript code is transpiled and executes correctly in the sandboxed goja VM +- TypeScript compilation errors are caught and reported with helpful hints +- Async/await syntax is automatically transpiled to Promise chains compatible with goja +- Error handling tests verify that helpful error hints are provided for both runtime and TypeScript compilation errors +- Agent mode tests verify the critical mixed auto-execute/non-auto-execute scenario where some tools are executed and others are returned for user approval diff --git a/tests/core-mcp/agent_mode_test.go b/tests/core-mcp/agent_mode_test.go new file mode 100644 index 000000000..8f3b00453 --- /dev/null +++ b/tests/core-mcp/agent_mode_test.go @@ -0,0 +1,77 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Note: Full agent mode testing requires integration with LLM calls. +// These tests verify the configuration and tool execution aspects that can be tested directly. +// For full agent mode flow testing, see integration_test.go + +// TestAgentModeConfiguration tests the configuration aspects of agent mode +// Full agent mode flow testing requires LLM integration (see integration_test.go) +func TestAgentModeConfiguration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test configuration: echo auto-execute, add non-auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // Only echo is auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + + // Verify configuration + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("add", bifrostClient.Config), "add should not be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") +} + +func TestAgentModeMaxDepthConfiguration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with max depth of 2 + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 2, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + // Verify max depth is configured + clients, err := b.GetMCPClients() + require.NoError(t, err) + assert.NotNil(t, clients, "Should have clients") +} diff --git a/tests/core-mcp/auto_execute_config_test.go b/tests/core-mcp/auto_execute_config_test.go new file mode 100644 index 000000000..a0bff19d5 --- /dev/null +++ b/tests/core-mcp/auto_execute_config_test.go @@ -0,0 +1,322 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToolInToolsToExecuteButNotInToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in ToolsToExecute but not in ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") +} + +func TestToolInBothToolsToExecuteAndToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in both lists + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToAutoExecute, "echo") + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") +} + +func TestToolInToolsToAutoExecuteButNotInToolsToExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in ToolsToAutoExecute but not in ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"add"}, // echo not in this list + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // echo should not be auto-executable because it's not in ToolsToExecute + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (not in ToolsToExecute)") +} + +func TestWildcardInToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure wildcard in ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToAutoExecute, "*") + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable with wildcard") + assert.True(t, canAutoExecuteTool("add", bifrostClient.Config), "add should be auto-executable with wildcard") +} + +func TestEmptyToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure empty ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") +} + +func TestNilToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure nil ToolsToAutoExecute (omitted) + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + // ToolsToAutoExecute omitted (nil) + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // nil should be treated as empty + if bifrostClient.Config.ToolsToAutoExecute == nil { + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (nil treated as empty)") + } else { + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") + } +} + +func TestMultipleToolsWithMixedAutoExecuteConfigs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure mixed: echo auto-execute, add non-auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo", "add", "multiply"}, + ToolsToAutoExecute: []string{"echo", "multiply"}, // add not in auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("add", bifrostClient.Config), "add should not be auto-executable") + assert.True(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should be auto-executable") +} + +func TestToolsToExecuteEmptyList(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure empty ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{}, // Empty - no tools allowed + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Empty(t, bifrostClient.Config.ToolsToExecute) + // Even with wildcard in ToolsToAutoExecute, tools not in ToolsToExecute should not be auto-executable + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (not in ToolsToExecute)") +} + +func TestToolsToExecuteNil(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure nil ToolsToExecute (omitted) + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + // ToolsToExecute omitted (nil) + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // nil ToolsToExecute should be treated as empty + if bifrostClient.Config.ToolsToExecute == nil { + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (nil ToolsToExecute treated as empty)") + } else { + assert.Empty(t, bifrostClient.Config.ToolsToExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") + } +} diff --git a/tests/core-mcp/client_config_test.go b/tests/core-mcp/client_config_test.go new file mode 100644 index 000000000..7b7b9851d --- /dev/null +++ b/tests/core-mcp/client_config_test.go @@ -0,0 +1,346 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSingleCodeModeClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Find bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.True(t, bifrostClient.Config.IsCodeModeClient, "bifrostInternal should be code mode client") + assert.Equal(t, schemas.MCPConnectionStateConnected, bifrostClient.State) +} + +func TestSingleNonCodeModeClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Note: For in-process clients, we need to register tools first + err = registerTestTools(b) + require.NoError(t, err) + + // Update bifrostInternal to be non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.False(t, bifrostClient.Config.IsCodeModeClient, "bifrostInternal should be non-code mode client") +} + +func TestMultipleCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + codeModeCount := 0 + for _, client := range clients { + if client.Config.IsCodeModeClient { + codeModeCount++ + } + } + + assert.GreaterOrEqual(t, codeModeCount, 1, "Should have at least one code mode client") +} + +func TestMultipleNonCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + nonCodeModeCount := 0 + for _, client := range clients { + if !client.Config.IsCodeModeClient { + nonCodeModeCount++ + } + } + + assert.GreaterOrEqual(t, nonCodeModeCount, 1, "Should have at least one non-code mode client") +} + +func TestMixedCodeModeAndNonCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + codeModeCount := 0 + + for _, client := range clients { + if client.Config.IsCodeModeClient { + codeModeCount++ + } + } + + // At minimum, we should have bifrostInternal as code mode + assert.GreaterOrEqual(t, codeModeCount, 1, "Should have at least one code mode client") +} + +func TestClientConnectionStates(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // All clients should be connected + for _, client := range clients { + assert.Equal(t, schemas.MCPConnectionStateConnected, client.State, "Client %s should be connected", client.Config.ID) + } +} + +func TestClientWithNoTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Don't register any tools - bifrostInternal client should still exist but with no tools + clients, err := b.GetMCPClients() + require.NoError(t, err) + + // bifrostInternal client is created when MCP is initialized, but won't have tools until registered + // This test verifies the client exists even without tools + assert.NotNil(t, clients, "Clients list should exist") + + // Find bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.Empty(t, bifrostClient.Tools, "bifrostInternal client should have no tools") +} + +func TestClientWithEmptyToolLists(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to empty list + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Equal(t, []string{}, bifrostClient.Config.ToolsToExecute, "ToolsToExecute should be empty") +} + +func TestClientConfigUpdate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Initially, bifrostInternal should not be code mode (default) + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + initialIsCodeMode := bifrostClient.Config.IsCodeModeClient + + // Update to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + // Verify update + clients, err = b.GetMCPClients() + require.NoError(t, err) + + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.NotEqual(t, initialIsCodeMode, bifrostClient.Config.IsCodeModeClient, "IsCodeModeClient should have changed") + assert.True(t, bifrostClient.Config.IsCodeModeClient, "Should now be code mode") +} + +func TestClientWithToolsToExecuteWildcard(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to wildcard + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "*", "Should contain wildcard") +} + +func TestClientWithSpecificToolsToExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to specific tools + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo", "add"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "add") + assert.Len(t, bifrostClient.Config.ToolsToExecute, 2) +} diff --git a/tests/core-mcp/codemode_auto_execute_test.go b/tests/core-mcp/codemode_auto_execute_test.go new file mode 100644 index 000000000..d73c68cc4 --- /dev/null +++ b/tests/core-mcp/codemode_auto_execute_test.go @@ -0,0 +1,233 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestExecuteToolCodeWithAutoExecuteTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure echo as auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + // Test executeToolCode with code calling auto-execute tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithAutoExecuteTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithNonAutoExecuteTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure multiply as non-auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply not in auto-execute + }) + require.NoError(t, err) + + // Test executeToolCode with code calling non-auto-execute tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithNonAutoExecuteTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithMixedAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure echo as auto-execute, multiply as non-auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply not in auto-execute + }) + require.NoError(t, err) + + // Test executeToolCode with code calling mixed tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithMixedAutoExecute, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithNoToolCalls(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with no tool calls + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithNoToolCalls, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithListToolFiles(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // listToolFiles should always be auto-executable + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithListToolFiles, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // listToolFiles and readToolFile are code mode meta-tools and cannot be called from within executeToolCode + // They're only available as direct tool calls, not from within code execution + // So this will fail with a runtime error + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithReadToolFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // readToolFile should always be auto-executable + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithReadToolFile, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // listToolFiles and readToolFile are code mode meta-tools and cannot be called from within executeToolCode + // They're only available as direct tool calls, not from within code execution + // So this will fail with a runtime error + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithUndefinedServer(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with undefined server + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithUndefinedServer, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithUndefinedTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with undefined tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithUndefinedTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} diff --git a/tests/core-mcp/edge_cases_test.go b/tests/core-mcp/edge_cases_test.go new file mode 100644 index 000000000..dfc3e780c --- /dev/null +++ b/tests/core-mcp/edge_cases_test.go @@ -0,0 +1,299 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCodeModeClientCallingNonCodeModeClientTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code trying to call non-code mode client tool + // This should fail at runtime since non-code mode clients aren't available in code execution + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingNonCodeModeTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error - tool call succeeds but code execution fails + requireNoBifrostError(t, bifrostErr, "Tool call should succeed") + require.NotNil(t, result, "Result should be present") + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestNonCodeModeClientToolCalledFromExecuteToolCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Code mode can only call code mode client tools + // Non-code mode tools are not available in executeToolCode context + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": `const result = await NonExistentClient.tool({}); return result`, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error - tool call succeeds but code execution fails + requireNoBifrostError(t, bifrostErr, "Tool call should succeed") + require.NotNil(t, result, "Result should be present") + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestToolNotInToolsToExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure only echo in ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, // add not in list + }) + require.NoError(t, err) + + // Try to execute add tool (not in ToolsToExecute) + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(1), + "b": float64(2), + }) + _, bifrostErr := b.ExecuteMCPTool(ctx, addCall) + + // Should fail - tool not available + assert.NotNil(t, bifrostErr, "Should fail when tool not in ToolsToExecute") +} + +func TestToolExecutionTimeoutEdgeCase(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test slow tool with timeout + slowCall := createToolCall("slow_tool", map[string]interface{}{ + "delay_ms": float64(100), + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, slowCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Completed", "Should complete execution") +} + +func TestToolExecutionErrorPropagation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test error tool + errorCall := createToolCall("error_tool", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, errorCall) + + // Tool execution should succeed (no bifrostErr), but result should contain error message + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Error:", "Result should contain error message") + assert.Contains(t, responseText, "this tool always fails", "Result should contain the error text") +} + +func TestEmptyCodeExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.EmptyCode, + }) + + _, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Empty code should return an error + require.NotNil(t, bifrostErr, "Empty code should return an error") + assert.Contains(t, bifrostErr.Error.Message, "code parameter is required", "Error should mention code parameter") +} + +func TestCodeWithSyntaxErrors(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SyntaxError, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Syntax errors are caught during JavaScript execution (runtime), not TypeScript compilation + // The error will be a runtime SyntaxError + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestCodeWithTypeScriptCompilationErrors(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Invalid TypeScript code + invalidCode := `const x: string = 123; return x` + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": invalidCode, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // TypeScript type errors might not be caught - the code might execute successfully + // This is acceptable behavior if type checking is disabled + // Just verify the execution completed (either with error or success) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) +} + +func TestCodeWithRuntimeErrors(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.RuntimeError, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestCodeCallingToolsWithInvalidArguments(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Code calling tool with invalid arguments + invalidArgsCode := `const result = await BifrostClient.echo({invalid: "arg"}); return result` + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": invalidArgsCode, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail - tool expects "message" parameter + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "") +} + +func TestCodeModeToolsAlwaysAutoExecutable(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute configured + }) + require.NoError(t, err) + + // listToolFiles and readToolFile should always be auto-executable + // This is tested in integration tests that verify agent mode behavior + // For now, verify they can be executed directly + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) +} + +func TestCommentsOnlyCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CommentsOnly, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Comments-only code should execute (return null) + assertExecutionResult(t, result, true, nil, "") +} + +func TestUndefinedVariableError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.UndefinedVariable, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} diff --git a/tests/core-mcp/fixtures.go b/tests/core-mcp/fixtures.go new file mode 100644 index 000000000..fe8b5a82e --- /dev/null +++ b/tests/core-mcp/fixtures.go @@ -0,0 +1,311 @@ +package mcp + +// CodeFixtures contains sample TypeScript code snippets for testing +var CodeFixtures = struct { + // Basic expressions + SimpleExpression string + SimpleString string + VariableAssignment string + ConsoleLogging string + ExplicitReturn string + AutoReturnExpression string + + // MCP tool calls + SingleToolCall string + ToolCallWithPromise string + ToolCallChain string + ToolCallErrorHandling string + MultipleServerToolCalls string + ToolCallWithComplexArgs string + + // Import/Export + ImportStatement string + ExportStatement string + MultipleImportExport string + ImportExportWithComments string + + // Expression analysis + FunctionCallExpression string + PromiseChainExpression string + ObjectLiteralExpression string + AssignmentStatement string + ControlFlowStatement string + TopLevelReturn string + + // Error cases + UndefinedVariable string + UndefinedServer string + UndefinedTool string + SyntaxError string + RuntimeError string + + // Edge cases + NestedPromiseChains string + PromiseErrorHandling string + ComplexDataStructures string + MultiLineExpression string + EmptyCode string + CommentsOnly string + FunctionDefinition string + + // Environment tests + AsyncAwaitTest string + EnvironmentTest string + + // Long code test + LongCodeExecution string + + // Auto-execute validation tests + CodeWithAutoExecuteTool string + CodeWithNonAutoExecuteTool string + CodeWithMixedAutoExecute string + CodeWithMultipleClients string + CodeWithNoToolCalls string + CodeWithListToolFiles string + CodeWithReadToolFile string + + // Mixed client scenarios + CodeCallingCodeModeTool string + CodeCallingNonCodeModeTool string + CodeCallingMultipleServers string + CodeWithUndefinedServer string + CodeWithUndefinedTool string + + // Agent mode scenarios + CodeForAgentModeAutoExecute string + CodeForAgentModeNonAutoExecute string +}{ + SimpleExpression: `return 1 + 1`, + SimpleString: `return "hello"`, + VariableAssignment: `var x = 5; return x`, + ConsoleLogging: `console.log("test"); return "logged"`, + ExplicitReturn: `return 42`, + AutoReturnExpression: `return 2 + 2`, // Note: Now requires explicit return + + SingleToolCall: `const result = await BifrostClient.echo({message: "hello"}); return result`, + ToolCallWithPromise: `const result = await BifrostClient.echo({message: "test"}); console.log(result); return result`, + ToolCallChain: `const result1 = await BifrostClient.add({a: 1, b: 2}); const result2 = await BifrostClient.multiply({a: result1, b: 3}); return result2`, + ToolCallErrorHandling: `try { await BifrostClient.error_tool({}); } catch (err) { console.error(err); return "handled"; }`, + MultipleServerToolCalls: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await BifrostClient.add({a: 1, b: 2}); return r2`, + ToolCallWithComplexArgs: `return await BifrostClient.complex_args_tool({data: {nested: {value: 42}}})`, + + ImportStatement: `import { something } from "module"; return 1 + 1`, + ExportStatement: `export const x = 5; return x`, + MultipleImportExport: `import a from "a"; import b from "b"; export const c = 1; return 2 + 2`, + ImportExportWithComments: `// comment\nimport x from "x";\n// another comment\nreturn 2 + 2`, + + FunctionCallExpression: `return Math.max(1, 2)`, // Note: Now requires explicit return + PromiseChainExpression: `return Promise.resolve(1).then(x => x + 1)`, // Note: Now requires explicit return + ObjectLiteralExpression: `return {a: 1, b: 2}`, // Note: Now requires explicit return + AssignmentStatement: `var x = 5`, // Assignment statements don't return values + ControlFlowStatement: `if (true) { return 1; } else { return 2; }`, // Note: Now requires explicit return + TopLevelReturn: `return 42`, + + UndefinedVariable: `return undefinedVar`, // Will cause runtime error + UndefinedServer: `return nonexistentServer.tool({})`, // Will cause runtime error + UndefinedTool: `return BifrostClient.nonexistentTool({})`, // Will cause runtime error + SyntaxError: `var x = `, // Syntax error - no return needed + RuntimeError: `return null.someProperty`, // Will cause runtime error + + NestedPromiseChains: `return Promise.resolve(1).then(x => Promise.resolve(x + 1).then(y => y + 1))`, // Note: Now requires explicit return + PromiseErrorHandling: `return Promise.reject("error").catch(err => "handled")`, // Note: Now requires explicit return + ComplexDataStructures: `return [{a: 1}, {b: 2}].map(x => x.a || x.b)`, // Note: Now requires explicit return + MultiLineExpression: `const result = await BifrostClient.echo({message: "test"});\n return result`, // Note: Now requires explicit return + EmptyCode: ``, + CommentsOnly: `// comment\n/* another */`, + FunctionDefinition: `function test() { return 1; } return test()`, // Note: Now requires explicit return for function call + + AsyncAwaitTest: `async function test() { const result = await Promise.resolve(1); return result; } return test()`, + EnvironmentTest: `return __MCP_ENV__.serverKeys`, + + LongCodeExecution: `// Long and complex code execution test with extensive operations\n` + + `(async function() {\n` + + ` var results = [];\n` + + ` var sum = 0;\n` + + ` var processedData = [];\n` + + ` var executionLog = [];\n` + + ` \n` + + ` // Initialize execution context\n` + + ` var context = {\n` + + ` startTime: Date.now(),\n` + + ` steps: 0,\n` + + ` errors: [],\n` + + ` warnings: []\n` + + ` };\n` + + ` \n` + + ` try {\n` + + ` // Step 1: Initial echo call\n` + + ` const result1 = await BifrostClient.echo({message: "step1"});\n` + + ` console.log("Step 1 completed:", result1);\n` + + ` results.push(result1);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 1, action: "echo", result: result1});\n` + + ` \n` + + ` // Step 2: Add operation\n` + + ` const result2 = await BifrostClient.add({a: 10, b: 20});\n` + + ` console.log("Step 2 completed:", result2);\n` + + ` results.push(result2);\n` + + ` sum += result2;\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 2, action: "add", result: result2, sum: sum});\n` + + ` \n` + + ` // Conditional logic based on result\n` + + ` let result3;\n` + + ` if (result2 > 25) {\n` + + ` console.log("Result is greater than 25, proceeding with multiplication");\n` + + ` result3 = await BifrostClient.multiply({a: result2, b: 2});\n` + + ` } else {\n` + + ` console.log("Result is less than or equal to 25, using add again");\n` + + ` result3 = await BifrostClient.add({a: result2, b: 5});\n` + + ` }\n` + + ` console.log("Step 3 completed:", result3);\n` + + ` results.push(result3);\n` + + ` sum += result3;\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 3, action: "math", result: result3, sum: sum});\n` + + ` \n` + + ` // Step 4: Echo call\n` + + ` const result4 = await BifrostClient.echo({message: "step4"});\n` + + ` console.log("Step 4 completed:", result4);\n` + + ` results.push(result4);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 4, action: "echo", result: result4});\n` + + ` \n` + + ` // Complex loop with nested operations\n` + + ` for (var i = 0; i < 20; i++) {\n` + + ` sum += i;\n` + + ` if (i % 3 === 0) {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i * 2,\n` + + ` isMultipleOfThree: true\n` + + ` });\n` + + ` } else if (i % 2 === 0) {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i * 1.5,\n` + + ` isEven: true\n` + + ` });\n` + + ` } else {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i,\n` + + ` isOdd: true\n` + + ` });\n` + + ` }\n` + + ` }\n` + + ` \n` + + ` console.log("Processed", processedData.length, "data items");\n` + + ` \n` + + ` // Step 5: Get data\n` + + ` const result5 = await BifrostClient.get_data({key: "test"});\n` + + ` console.log("Step 5 completed:", result5);\n` + + ` results.push(result5);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 5, action: "get_data", result: result5});\n` + + ` \n` + + ` // Nested data processing\n` + + ` var nestedResults = [];\n` + + ` for (var j = 0; j < results.length; j++) {\n` + + ` var item = results[j];\n` + + ` nestedResults.push({\n` + + ` original: item,\n` + + ` processed: typeof item === "string" ? item.toUpperCase() : item * 1.1,\n` + + ` index: j,\n` + + ` metadata: {\n` + + ` type: typeof item,\n` + + ` isString: typeof item === "string",\n` + + ` isNumber: typeof item === "number"\n` + + ` }\n` + + ` });\n` + + ` }\n` + + ` \n` + + ` // Step 6: Final echo call\n` + + ` const result6 = await BifrostClient.echo({message: "final_step"});\n` + + ` console.log("Step 6 completed:", result6);\n` + + ` results.push(result6);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 6, action: "echo", result: result6});\n` + + ` \n` + + ` // Calculate statistics\n` + + ` var stats = {\n` + + ` totalResults: results.length,\n` + + ` numericSum: sum,\n` + + ` average: sum / results.length,\n` + + ` processedItems: processedData.length,\n` + + ` executionSteps: context.steps\n` + + ` };\n` + + ` \n` + + ` // Create comprehensive final data structure\n` + + ` var finalData = {\n` + + ` results: results,\n` + + ` processedData: processedData,\n` + + ` executionLog: executionLog,\n` + + ` statistics: stats,\n` + + ` context: {\n` + + ` steps: context.steps,\n` + + ` executionTime: Date.now() - context.startTime,\n` + + ` errors: context.errors,\n` + + ` warnings: context.warnings\n` + + ` },\n` + + ` metadata: {\n` + + ` executed: true,\n` + + ` completed: true,\n` + + ` totalOperations: context.steps,\n` + + ` dataProcessed: processedData.length,\n` + + ` finalSum: sum,\n` + + ` resultCount: results.length\n` + + ` }\n` + + ` };\n` + + ` \n` + + ` console.log("Final statistics:", JSON.stringify(stats));\n` + + ` console.log("Execution completed successfully with", context.steps, "steps");\n` + + ` console.log("Processed", processedData.length, "data items");\n` + + ` console.log("Final sum:", sum);\n` + + ` \n` + + ` return finalData;\n` + + ` } catch (error) {\n` + + ` console.error("Error in long execution:", error);\n` + + ` context.errors.push(error.toString());\n` + + ` return {\n` + + ` error: error.toString(),\n` + + ` context: context,\n` + + ` partialResults: results,\n` + + ` partialSum: sum\n` + + ` };\n` + + ` }\n` + + `})()`, + + // Auto-execute validation tests + CodeWithAutoExecuteTool: `const result = await BifrostClient.echo({message: "auto-execute"}); return result`, + CodeWithNonAutoExecuteTool: `const result = await BifrostClient.multiply({a: 2, b: 3}); return result`, + CodeWithMixedAutoExecute: `const r1 = await BifrostClient.echo({message: "auto"}); const r2 = await BifrostClient.multiply({a: 2, b: 3}); return r2`, + CodeWithMultipleClients: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await Server2.add({a: 1, b: 2}); return r2`, + CodeWithNoToolCalls: `return 42`, + CodeWithListToolFiles: `const files = await BifrostClient.listToolFiles({}); return files`, + CodeWithReadToolFile: `const content = await BifrostClient.readToolFile({fileName: "BifrostClient.d.ts"}); return content`, + + // Mixed client scenarios + CodeCallingCodeModeTool: `const result = await BifrostClient.echo({message: "test"}); return result`, + CodeCallingNonCodeModeTool: `const result = await NonCodeModeClient.someTool({}); return result`, + CodeCallingMultipleServers: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await Server2.add({a: 1, b: 2}); return {r1, r2}`, + CodeWithUndefinedServer: `const result = await UndefinedServer.tool({}); return result`, + CodeWithUndefinedTool: `const result = await BifrostClient.undefinedTool({}); return result`, + + // Agent mode scenarios + CodeForAgentModeAutoExecute: `const result = await BifrostClient.echo({message: "agent-auto"}); return result`, + CodeForAgentModeNonAutoExecute: `const result = await BifrostClient.multiply({a: 5, b: 6}); return result`, +} + +// ExpectedResults contains expected results for validation +var ExpectedResults = struct { + SimpleExpressionResult interface{} + EchoResult string + AddResult float64 + MultiplyResult float64 +}{ + SimpleExpressionResult: float64(2), + EchoResult: "hello", + AddResult: float64(3), + MultiplyResult: float64(6), +} diff --git a/tests/core-mcp/go.mod b/tests/core-mcp/go.mod new file mode 100644 index 000000000..2d10ebcb6 --- /dev/null +++ b/tests/core-mcp/go.mod @@ -0,0 +1,63 @@ +module github.com/maximhq/bifrost/tests/core-mcp + +go 1.24.3 + +replace github.com/maximhq/bifrost/core => ../../core + +require ( + github.com/maximhq/bifrost/core v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.11.1 +) + +require ( + cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.39.5 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.13 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect + github.com/aws/smithy-go v1.23.1 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.1 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.41.1 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.67.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.22.0 // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/oauth2 v0.32.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/tests/core-mcp/go.sum b/tests/core-mcp/go.sum new file mode 100644 index 000000000..73a452e71 --- /dev/null +++ b/tests/core-mcp/go.sum @@ -0,0 +1,141 @@ +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.39.5 h1:e/SXuia3rkFtapghJROrydtQpfQaaUgd1cUvyO1mp2w= +github.com/aws/aws-sdk-go-v2 v1.39.5/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 h1:t9yYsydLYNBk9cJ73rgPhPWqOh/52fcWDQB5b1JsKSY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2/go.mod h1:IusfVNTmiSN3t4rhxWFaBAqn+mcNdwKtPcV16eYdgko= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.12 h1:p/9flfXdoAnwJnuW9xHEAFY22R3A6skYkW19JFF9F+8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.12/go.mod h1:ZTLHakoVCTtW8AaLGSwJ3LXqHD9uQKnOcv1TrpO6u2k= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.12 h1:2lTWFvRcnWFFLzHWmtddu5MTchc5Oj2OOey++99tPZ0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.12/go.mod h1:hI92pK+ho8HVcWMHKHrK3Uml4pfG7wvL86FzO0LVtQQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 h1:e+8XbKB6IMn8A4OAyZccO4pYfB3s7bt6azNIPE7AnPg= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= +github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.41.1 h1:w78eWfiQam2i8ICL7AL0WFiq7KHNJQ6UB53ZVtH4KGA= +github.com/mark3labs/mcp-go v0.41.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.67.0 h1:tqKlJMUP6iuNG8hGjK/s9J4kadH7HLV4ijEcPGsezac= +github.com/valyala/fasthttp v1.67.0/go.mod h1:qYSIpqt/0XNmShgo/8Aq8E3UYWVVwNS2QYmzd8WIEPM= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= +golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-mcp/integration_test.go b/tests/core-mcp/integration_test.go new file mode 100644 index 000000000..def838b6a --- /dev/null +++ b/tests/core-mcp/integration_test.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFullWorkflowListToolFilesReadToolFileExecuteToolCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Step 1: List tool files + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient") + + // Step 2: Read tool file + readCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, readCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain interface definitions") + assert.Contains(t, responseText, "echo", "Should contain echo tool") + + // Step 3: Execute code using the discovered tools + executeCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, executeCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestMultipleCodeModeClientsWithDifferentAutoExecuteConfigs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with mixed auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo", "add"}, // multiply not auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config)) + assert.True(t, canAutoExecuteTool("add", bifrostClient.Config)) + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config)) +} + +func TestToolFilteringWithCodeMode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure specific tools only + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"echo", "add"}, // Only these tools available + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "add") + assert.NotContains(t, bifrostClient.Config.ToolsToExecute, "multiply") +} + +func TestCodeModeAndNonCodeModeToolsInSameRequest(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Code mode tools should be available + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Verify direct tools are not exposed for code-mode clients + // Code mode clients expose tools via executeToolCode, not as direct tool calls + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + _, bifrostErr = b.ExecuteMCPTool(ctx, echoCall) + require.NotNil(t, bifrostErr, "Direct tool call should fail for code-mode client") + assert.Contains(t, bifrostErr.Error.Message, "not available", "Error should indicate tool is not available") +} + +func TestComplexCodeExecutionWithMultipleToolCalls(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test complex code with multiple tool calls + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.ToolCallChain, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestCodeExecutionWithErrorHandling(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code with error handling + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.ToolCallErrorHandling, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "handled") +} + +func TestCodeExecutionWithAsyncAwait(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test async/await syntax + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.AsyncAwaitTest, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestLongCodeExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test long and complex code execution + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.LongCodeExecution, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} diff --git a/tests/core-mcp/mcp_connection_test.go b/tests/core-mcp/mcp_connection_test.go new file mode 100644 index 000000000..e55354931 --- /dev/null +++ b/tests/core-mcp/mcp_connection_test.go @@ -0,0 +1,299 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPManagerInitialization(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + require.NotNil(t, b) + + // Verify MCP is configured + clients, err := b.GetMCPClients() + require.NoError(t, err) + assert.NotNil(t, clients) +} + +func TestLocalToolRegistration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Verify tools are available + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Find the bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.Equal(t, schemas.MCPConnectionStateConnected, bifrostClient.State) + + // Verify tools are registered + toolNames := make(map[string]bool) + for _, tool := range bifrostClient.Tools { + toolNames[tool.Name] = true + } + + assert.True(t, toolNames["echo"], "echo tool should be registered") + assert.True(t, toolNames["add"], "add tool should be registered") + assert.True(t, toolNames["multiply"], "multiply tool should be registered") +} + +func TestToolDiscovery(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Use CodeMode since we're testing CodeMode tools (listToolFiles, readToolFile) + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test listToolFiles + listToolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "servers/", "Should list servers") + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") + + // Test readToolFile + readToolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, readToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain TypeScript interface declarations") + assert.Contains(t, responseText, "echo", "Should contain echo tool definition") + assert.Contains(t, responseText, "EchoInput", "Should contain echo input interface") +} + +func TestToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test echo tool + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText) + + // Test add tool + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(5), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText) + + // Test multiply tool + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(4), + "b": float64(7), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Equal(t, "28", responseText) +} + +func TestMultipleServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Use CodeMode since we're testing CodeMode tools (listToolFiles) + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Verify we have at least one server + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Test listToolFiles with multiple servers + listToolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") +} + +// TestExternalMCPConnection tests connection to external MCP server +// This test requires external MCP credentials to be provided via environment variables +// or test configuration. For now, it's a placeholder that can be enabled when credentials are available. +func TestExternalMCPConnection(t *testing.T) { + t.Skip("Skipping external MCP connection test - requires credentials") + + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + _, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Example: Connect to external MCP server + // Uncomment and configure when credentials are available + /* + connectionString := os.Getenv("EXTERNAL_MCP_CONNECTION_STRING") + if connectionString == "" { + t.Skip("EXTERNAL_MCP_CONNECTION_STRING not set") + } + + err = connectExternalMCP(b, "external-server", "external-1", "http", connectionString) + require.NoError(t, err) + + // Verify connection + clients := b.GetMCPClients() + found := false + for _, client := range clients { + if client.Config.ID == "external-1" { + found = true + assert.Equal(t, schemas.MCPConnectionStateConnected, client.State) + break + } + } + assert.True(t, found, "External client should be connected") + */ +} + +func TestToolExecutionTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test slow tool with short timeout + slowCall := createToolCall("slow_tool", map[string]interface{}{ + "delay_ms": float64(100), + }) + + start := time.Now() + result, bifrostErr := b.ExecuteMCPTool(ctx, slowCall) + duration := time.Since(start) + + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assert.GreaterOrEqual(t, duration, 100*time.Millisecond, "Should take at least 100ms") +} + +func TestToolExecutionError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test error tool - tool execution succeeds but result contains error message + errorCall := createToolCall("error_tool", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, errorCall) + + // Tool execution should succeed (no bifrostErr), but result should contain error message + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Error:", "Result should contain error message") + assert.Contains(t, responseText, "this tool always fails", "Result should contain the error text") +} + +func TestComplexArgsTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test complex args tool + complexCall := createToolCall("complex_args_tool", map[string]interface{}{ + "data": map[string]interface{}{ + "nested": map[string]interface{}{ + "value": float64(42), + "array": []interface{}{1, 2, 3}, + }, + }, + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, complexCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Received data", "Should process complex args") + assert.Contains(t, responseText, "42", "Should contain nested value") +} diff --git a/tests/core-mcp/responses_test.go b/tests/core-mcp/responses_test.go new file mode 100644 index 000000000..d9c034788 --- /dev/null +++ b/tests/core-mcp/responses_test.go @@ -0,0 +1,442 @@ +package mcp + +import ( + "context" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestResponsesNonCodeModeToolExecution tests direct tool execution via Responses API +func TestResponsesNonCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode and ensure tools are available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, // Allow all tools + }) + require.NoError(t, err) + + // Execute tool directly to verify it works + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText, "Echo tool should return the input message") +} + +// TestResponsesCodeModeToolExecution tests code mode tool execution via Responses API +func TestResponsesCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode directly to verify code mode works + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SimpleExpression, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "completed successfully") +} + +// TestResponsesAgentModeWithAutoExecuteTools tests agent mode configuration with auto-executable tools +func TestResponsesAgentModeWithAutoExecuteTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with echo as auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // Only echo is auto-execute + }) + require.NoError(t, err) + + // Verify configuration + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") + + // Verify echo tool can be executed directly + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText, "Echo tool should return the input message") +} + +// TestResponsesAgentModeWithNonAutoExecuteTools tests agent mode configuration with non-auto-executable tools +func TestResponsesAgentModeWithNonAutoExecuteTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with multiply NOT in auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply is NOT auto-execute + }) + require.NoError(t, err) + + // Verify configuration + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") + + // Verify multiply tool can still be executed directly (just not auto-executed) + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(2), + "b": float64(3), + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "6", responseText, "Multiply tool should return correct result") +} + +// TestResponsesAgentModeMaxDepth tests agent mode max depth configuration via Responses API +func TestResponsesAgentModeMaxDepth(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with max depth of 2 + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 2, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Verify tools still work with max depth configured + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should work with max depth configured") +} + +// TestResponsesToolExecutionTimeout tests tool execution timeout via Responses API +func TestResponsesToolExecutionTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with short timeout + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 100 * time.Millisecond, // Very short timeout + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure slow_tool + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Create a Responses request that will trigger a slow tool + req := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Call slow_tool with delay 500ms"), + }, + }, + }, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Name: schemas.Ptr("slow_tool"), + Description: schemas.Ptr("A tool that takes time to execute"), + }, + }, + }, + } + + // Execute the request - should handle timeout gracefully + _, bifrostErr := b.ResponsesRequest(ctx, req) + // Timeout errors are acceptable in this test + if bifrostErr != nil { + assert.Contains(t, bifrost.GetErrorMessage(bifrostErr), "timeout", "Should contain timeout error") + } +} + +// TestResponsesMultipleToolCalls tests multiple tool calls via Responses API +func TestResponsesMultipleToolCalls(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Test echo tool + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should return correct result") + + // Test add tool + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(5), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText, "Add tool should return correct result") +} + +// TestResponsesCodeModeWithCodeExecution tests code mode with code execution via Responses API +func TestResponsesCodeModeWithCodeExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling code mode client tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "test") +} + +// TestResponsesToolFiltering tests tool filtering via Responses API +func TestResponsesToolFiltering(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure specific tools only + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"echo", "add"}, // Only these tools available + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + // Verify allowed tools work + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should work") + + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(1), + "b": float64(2), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "3", responseText, "Add tool should work") + + // Verify multiply tool is NOT available (should fail) + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(2), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, multiplyCall) + // Should fail because multiply is not in ToolsToExecute + assert.NotNil(t, bifrostErr, "Multiply tool should fail when not in ToolsToExecute") +} + +// TestResponsesComplexWorkflow tests a complex workflow via Responses API +func TestResponsesComplexWorkflow(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Test echo tool + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "hello", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "hello", responseText, "Echo tool should return correct result") + + // Test add tool + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(5), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText, "Add tool should return correct result") + + // Test multiply tool with result from add + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(8), // Result from add + "b": float64(2), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "16", responseText, "Multiply tool should return correct result") +} diff --git a/tests/core-mcp/setup.go b/tests/core-mcp/setup.go new file mode 100644 index 000000000..b6e6cdac4 --- /dev/null +++ b/tests/core-mcp/setup.go @@ -0,0 +1,402 @@ +package mcp + +import ( + "context" + "fmt" + "os" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestTimeout defines the maximum duration for MCP tests +const TestTimeout = 10 * time.Minute + +// TestAccount is a minimal account implementation for testing +type TestAccount struct{} + +func (a *TestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +func (a *TestAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil +} + +func (a *TestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// setupTestBifrost initializes and returns a Bifrost instance for testing +// This creates a basic Bifrost instance without any MCP clients configured +func setupTestBifrost(ctx context.Context) (*bifrost.Bifrost, error) { + return setupTestBifrostWithMCPConfig(ctx, &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + }) +} + +// setupTestBifrostWithCodeMode initializes and returns a Bifrost instance for testing with CodeMode +// This sets up bifrostInternal client as a code mode client +// Note: Tools must be registered first to create the bifrostInternal client +func setupTestBifrostWithCodeMode(ctx context.Context) (*bifrost.Bifrost, error) { + b, err := setupTestBifrost(ctx) + if err != nil { + return nil, err + } + + // Register tools first to create the bifrostInternal client + err = registerTestTools(b) + if err != nil { + return nil, fmt.Errorf("failed to register test tools: %w", err) + } + + // Get current client config to preserve existing settings + clients, err := b.GetMCPClients() + if err != nil { + return nil, fmt.Errorf("failed to get MCP clients: %w", err) + } + + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + + if currentConfig == nil { + return nil, fmt.Errorf("bifrostInternal client not found") + } + + // Set bifrostInternal client to code mode and ensure tools are available + // Preserve existing ToolsToExecute if set, otherwise use wildcard + toolsToExecute := currentConfig.ToolsToExecute + if len(toolsToExecute) == 0 { + toolsToExecute = []string{"*"} + } + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: true, + ToolsToExecute: toolsToExecute, + ToolsToAutoExecute: currentConfig.ToolsToAutoExecute, + }) + if err != nil { + return nil, fmt.Errorf("failed to set bifrostInternal client to code mode: %w", err) + } + + return b, nil +} + +// setupTestBifrostWithMCPConfig initializes Bifrost with custom MCP config +func setupTestBifrostWithMCPConfig(ctx context.Context, mcpConfig *schemas.MCPConfig) (*bifrost.Bifrost, error) { + account := &TestAccount{} + + // Ensure FetchNewRequestIDFunc is set if not provided + // This is required for the tools handler to be fully setup + if mcpConfig.FetchNewRequestIDFunc == nil { + mcpConfig.FetchNewRequestIDFunc = func(ctx context.Context) string { + return "test-request-id" + } + } + + if mcpConfig.ToolManagerConfig == nil { + mcpConfig.ToolManagerConfig = &schemas.MCPToolManagerConfig{ + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + } + } + + b, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: account, + Plugins: nil, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + MCPConfig: mcpConfig, + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize Bifrost: %w", err) + } + + return b, nil +} + +// registerTestTools registers simple test tools for testing +func registerTestTools(b *bifrost.Bifrost) error { + // Echo tool + echoSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "echo", + Description: schemas.Ptr("Echoes back the input message"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + } + if err := b.RegisterMCPTool("echo", "Echoes back the input message", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + message, ok := argsMap["message"].(string) + if !ok { + return "", fmt.Errorf("message field is required") + } + return message, nil + }, echoSchema); err != nil { + return fmt.Errorf("failed to register echo tool: %w", err) + } + + // Add tool + addSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "add", + Description: schemas.Ptr("Adds two numbers"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"a", "b"}, + }, + }, + } + if err := b.RegisterMCPTool("add", "Adds two numbers", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + a, ok := argsMap["a"].(float64) + if !ok { + return "", fmt.Errorf("a field is required") + } + bVal, ok := argsMap["b"].(float64) + if !ok { + return "", fmt.Errorf("b field is required") + } + return fmt.Sprintf("%.0f", a+bVal), nil + }, addSchema); err != nil { + return fmt.Errorf("failed to register add tool: %w", err) + } + + // Multiply tool + multiplySchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "multiply", + Description: schemas.Ptr("Multiplies two numbers"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"a", "b"}, + }, + }, + } + if err := b.RegisterMCPTool("multiply", "Multiplies two numbers", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + a, ok := argsMap["a"].(float64) + if !ok { + return "", fmt.Errorf("a field is required") + } + bVal, ok := argsMap["b"].(float64) + if !ok { + return "", fmt.Errorf("b field is required") + } + return fmt.Sprintf("%.0f", a*bVal), nil + }, multiplySchema); err != nil { + return fmt.Errorf("failed to register multiply tool: %w", err) + } + + // GetData tool - returns structured data + getDataSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_data", + Description: schemas.Ptr("Returns structured data"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{}, + Required: []string{}, + }, + }, + } + if err := b.RegisterMCPTool("get_data", "Returns structured data", func(args any) (string, error) { + return `{"items": [{"id": 1, "name": "test"}, {"id": 2, "name": "example"}]}`, nil + }, getDataSchema); err != nil { + return fmt.Errorf("failed to register get_data tool: %w", err) + } + + // ErrorTool - always returns an error + errorToolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "error_tool", + Description: schemas.Ptr("A tool that always returns an error"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{}, + Required: []string{}, + }, + }, + } + if err := b.RegisterMCPTool("error_tool", "A tool that always returns an error", func(args any) (string, error) { + return "", fmt.Errorf("this tool always fails") + }, errorToolSchema); err != nil { + return fmt.Errorf("failed to register error_tool: %w", err) + } + + // SlowTool - takes time to execute + slowToolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "slow_tool", + Description: schemas.Ptr("A tool that takes time to execute"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "delay_ms": map[string]interface{}{ + "type": "number", + "description": "Delay in milliseconds", + }, + }, + Required: []string{"delay_ms"}, + }, + }, + } + if err := b.RegisterMCPTool("slow_tool", "A tool that takes time to execute", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + delayMs, ok := argsMap["delay_ms"].(float64) + if !ok { + return "", fmt.Errorf("delay_ms field is required") + } + time.Sleep(time.Duration(delayMs) * time.Millisecond) + return fmt.Sprintf("Completed after %v ms", delayMs), nil + }, slowToolSchema); err != nil { + return fmt.Errorf("failed to register slow_tool: %w", err) + } + + // ComplexArgsTool - accepts complex nested arguments + complexArgsSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "complex_args_tool", + Description: schemas.Ptr("A tool that accepts complex nested arguments"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "data": map[string]interface{}{ + "type": "object", + "description": "Complex nested data", + }, + }, + Required: []string{"data"}, + }, + }, + } + if err := b.RegisterMCPTool("complex_args_tool", "A tool that accepts complex nested arguments", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + data, ok := argsMap["data"] + if !ok { + return "", fmt.Errorf("data field is required") + } + return fmt.Sprintf("Received data: %v", data), nil + }, complexArgsSchema); err != nil { + return fmt.Errorf("failed to register complex_args_tool: %w", err) + } + + return nil +} + +// connectExternalMCP connects to an external MCP server +// This is a helper function that can be used when external MCP credentials are provided +func connectExternalMCP(b *bifrost.Bifrost, name, id, connectionType, connectionString string) error { + var clientConfig schemas.MCPClientConfig + + switch connectionType { + case "http": + clientConfig = schemas.MCPClientConfig{ + ID: id, + Name: name, + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: schemas.Ptr(connectionString), + } + case "sse": + clientConfig = schemas.MCPClientConfig{ + ID: id, + Name: name, + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: schemas.Ptr(connectionString), + } + default: + return fmt.Errorf("unsupported connection type: %s", connectionType) + } + + clients, err := b.GetMCPClients() + if err != nil { + return fmt.Errorf("failed to get MCP clients: %w", err) + } + for _, client := range clients { + if client.Config.ID == id { + // Client already exists + return nil + } + } + + if err := b.AddMCPClient(clientConfig); err != nil { + return fmt.Errorf("failed to add external MCP client: %w", err) + } + + return nil +} diff --git a/tests/core-mcp/tool_execution_test.go b/tests/core-mcp/tool_execution_test.go new file mode 100644 index 000000000..5f3467805 --- /dev/null +++ b/tests/core-mcp/tool_execution_test.go @@ -0,0 +1,246 @@ +package mcp + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNonCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode and ensure tools are available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, // Allow all tools + }) + require.NoError(t, err) + + // Test direct tool execution + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText) +} + +func TestCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SimpleExpression, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "completed successfully") +} + +func TestCodeModeCallingCodeModeClientTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling code mode client tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "test") +} + +func TestCodeModeCallingMultipleCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling tools from multiple code mode clients + // Since we only have bifrostInternal, we'll test calling multiple tools from the same client + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.MultipleServerToolCalls, // This calls echo and add from BifrostClient + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestListToolFilesWithNoClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Don't register tools or set code mode - should have no code mode clients + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + // listToolFiles should still work but return empty/no servers message + if bifrostErr == nil && result != nil { + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "No servers", "Should indicate no servers") + } +} + +func TestListToolFilesWithOnlyNonCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + // listToolFiles should not be available when no code mode clients exist + // But if it is called, it should return empty + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + if bifrostErr == nil && result != nil { + responseText := *result.Content.ContentStr + // Should indicate no servers or empty list + assert.True(t, + len(responseText) == 0 || + strings.Contains(responseText, "No servers") || strings.Contains(responseText, "servers/"), + "Should return empty or no servers message") + } +} + +func TestListToolFilesWithCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "servers/", "Should list servers") + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") +} + +func TestReadToolFileForNonExistentClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "NonExistentClient.d.ts", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "No server found", "Should indicate server not found") +} + +func TestReadToolFileForCodeModeClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain TypeScript interface declarations") + assert.Contains(t, responseText, "echo", "Should contain echo tool definition") +} + +func TestReadToolFileWithLineRange(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + "startLine": float64(1), + "endLine": float64(10), + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.NotEmpty(t, responseText, "Should return content") +} diff --git a/tests/core-mcp/utils.go b/tests/core-mcp/utils.go new file mode 100644 index 000000000..f48bb5f5b --- /dev/null +++ b/tests/core-mcp/utils.go @@ -0,0 +1,104 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "slices" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createToolCall creates a tool call message for testing +func createToolCall(toolName string, arguments map[string]interface{}) schemas.ChatAssistantMessageToolCall { + argsJSON, _ := json.Marshal(arguments) + argsStr := string(argsJSON) + id := fmt.Sprintf("test-tool-call-%d", len(argsStr)) + toolType := "function" + + return schemas.ChatAssistantMessageToolCall{ + ID: &id, + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolName, + Arguments: argsStr, + }, + } +} + +// assertExecutionResult validates execution results +func assertExecutionResult(t *testing.T, result *schemas.ChatMessage, expectedSuccess bool, expectedLogs []string, expectedErrorKind string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + + if expectedSuccess { + // Success case - should not contain error indicators (but allow console.error output) + assert.NotContains(t, responseText, "Execution runtime error", "Response should not contain execution runtime error for successful execution") + assert.NotContains(t, responseText, "Execution typescript error", "Response should not contain execution typescript error for successful execution") + assert.NotContains(t, responseText, "Error:", "Response should not contain Error: prefix for successful execution") + + // Check logs if expected + if len(expectedLogs) > 0 { + for _, expectedLog := range expectedLogs { + assert.Contains(t, responseText, expectedLog, "Response should contain expected log") + } + } + } else { + // Error case - should contain error information + assert.Contains(t, responseText, "error", "Response should contain error for failed execution") + + if expectedErrorKind != "" { + assert.Contains(t, responseText, expectedErrorKind, "Response should contain expected error kind") + } + } +} + +// assertResultContains validates that the result contains specific text +func assertResultContains(t *testing.T, result *schemas.ChatMessage, expectedText string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, expectedText, "Response should contain expected text") +} + +// requireNoBifrostError asserts that bifrostErr is nil, using GetErrorMessage for better error reporting +func requireNoBifrostError(t *testing.T, bifrostErr *schemas.BifrostError, msgAndArgs ...interface{}) { + if bifrostErr != nil { + errorMsg := bifrost.GetErrorMessage(bifrostErr) + if len(msgAndArgs) > 0 { + require.Fail(t, fmt.Sprintf("Expected no error but got: %s", errorMsg), msgAndArgs...) + } else { + require.Fail(t, fmt.Sprintf("Expected no error but got: %s", errorMsg)) + } + } +} + +// canAutoExecuteTool checks if a tool can be auto-executed based on client config +func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { + // First check if tool is in ToolsToExecute + if config.ToolsToExecute != nil { + if len(config.ToolsToExecute) == 0 { + return false // Empty list means no tools allowed + } + if !slices.Contains(config.ToolsToExecute, "*") && !slices.Contains(config.ToolsToExecute, toolName) { + return false // Tool not in allowed list + } + } else { + return false // nil means no tools allowed + } + + // Then check if tool is in ToolsToAutoExecute + if len(config.ToolsToAutoExecute) == 0 { + return false // No auto-execute tools configured + } + + return slices.Contains(config.ToolsToAutoExecute, "*") || slices.Contains(config.ToolsToAutoExecute, toolName) +} diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index b7669f2be..4b0509d2c 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -12,6 +12,7 @@ import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/network" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" @@ -28,6 +29,7 @@ type ConfigManager interface { ReloadPricingManager(ctx context.Context) error ForceReloadPricing(ctx context.Context) error UpdateDropExcessRequests(ctx context.Context, value bool) + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error ReloadProxyConfig(ctx context.Context, config *configstoreTables.GlobalProxyConfig) error } @@ -215,6 +217,52 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.DropExcessRequests = payload.ClientConfig.DropExcessRequests } + // Validate MCP tool manager config values before updating + if payload.ClientConfig.MCPAgentDepth <= 0 { + logger.Warn("mcp_agent_depth must be greater than 0") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_agent_depth must be greater than 0") + return + } + + if payload.ClientConfig.MCPToolExecutionTimeout <= 0 { + logger.Warn("mcp_tool_execution_timeout must be greater than 0") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_tool_execution_timeout must be greater than 0") + return + } + + if payload.ClientConfig.MCPCodeModeBindingLevel != "" { + if payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelServer) && payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelTool) { + logger.Warn("mcp_code_mode_binding_level must be 'server' or 'tool'") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_code_mode_binding_level must be 'server' or 'tool'") + return + } + } + + shouldReloadMCPToolManagerConfig := false + + if payload.ClientConfig.MCPAgentDepth != currentConfig.MCPAgentDepth { + updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth + shouldReloadMCPToolManagerConfig = true + } + + if payload.ClientConfig.MCPToolExecutionTimeout != currentConfig.MCPToolExecutionTimeout { + updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout + shouldReloadMCPToolManagerConfig = true + } + + if payload.ClientConfig.MCPCodeModeBindingLevel != "" && payload.ClientConfig.MCPCodeModeBindingLevel != currentConfig.MCPCodeModeBindingLevel { + updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel + shouldReloadMCPToolManagerConfig = true + } + + if shouldReloadMCPToolManagerConfig { + if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel); err != nil { + logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err)) + return + } + } + if !slices.Equal(payload.ClientConfig.PrometheusLabels, currentConfig.PrometheusLabels) { updatedConfig.PrometheusLabels = payload.ClientConfig.PrometheusLabels shouldReloadTelemetryPlugin = true @@ -232,7 +280,12 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.AllowDirectKeys = payload.ClientConfig.AllowDirectKeys updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB updatedConfig.EnableLiteLLMFallbacks = payload.ClientConfig.EnableLiteLLMFallbacks - + updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth + updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout + // Only update MCPCodeModeBindingLevel if payload is non-empty to avoid clearing stored value + if payload.ClientConfig.MCPCodeModeBindingLevel != "" { + updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel + } // Validate LogRetentionDays if payload.ClientConfig.LogRetentionDays < 1 { logger.Warn("log_retention_days must be at least 1") diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index cc04a411b..05dd0d2a3 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -22,6 +22,7 @@ import ( // GovernanceManager is the interface for the governance manager type GovernanceManager interface { + GetGovernanceData() *governance.GovernanceData ReloadVirtualKey(ctx context.Context, id string) (*configstoreTables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error ReloadTeam(ctx context.Context, id string) (*configstoreTables.TableTeam, error) @@ -38,6 +39,9 @@ type GovernanceHandler struct { // NewGovernanceHandler creates a new governance handler instance func NewGovernanceHandler(manager GovernanceManager, configStore configstore.ConfigStore) (*GovernanceHandler, error) { + if manager == nil { + return nil, fmt.Errorf("governance manager is required") + } if configStore == nil { return nil, fmt.Errorf("config store is required") } @@ -171,12 +175,30 @@ func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...lib. r.GET("/api/governance/customers/{customer_id}", lib.ChainMiddlewares(h.getCustomer, middlewares...)) r.PUT("/api/governance/customers/{customer_id}", lib.ChainMiddlewares(h.updateCustomer, middlewares...)) r.DELETE("/api/governance/customers/{customer_id}", lib.ChainMiddlewares(h.deleteCustomer, middlewares...)) + + // Budget and Rate Limit GET operations + r.GET("/api/governance/budgets", lib.ChainMiddlewares(h.getBudgets, middlewares...)) + r.GET("/api/governance/rate-limits", lib.ChainMiddlewares(h.getRateLimits, middlewares...)) } // Virtual Key CRUD Operations // getVirtualKeys handles GET /api/governance/virtual-keys - Get all virtual keys with relationships func (h *GovernanceHandler) getVirtualKeys(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "virtual_keys": data.VirtualKeys, + "count": len(data.VirtualKeys), + }) + return + } // Preload all relationships for complete information virtualKeys, err := h.configStore.GetVirtualKeys(ctx) if err != nil { @@ -285,29 +307,29 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } } - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) - } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + // Get keys for this provider config if specified + var keys []configstoreTables.TableKey + if len(pc.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } } - } - providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ - VirtualKeyID: vk.ID, - Provider: pc.Provider, - Weight: pc.Weight, - AllowedModels: pc.AllowedModels, - Keys: keys, - } + providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: pc.Weight, + AllowedModels: pc.AllowedModels, + Keys: keys, + } - // Create budget for provider config if provided - if pc.Budget != nil { + // Create budget for provider config if provided + if pc.Budget != nil { budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: pc.Budget.MaxLimit, @@ -397,6 +419,25 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { // getVirtualKey handles GET /api/governance/virtual-keys/{vk_id} - Get a specific virtual key func (h *GovernanceHandler) getVirtualKey(ctx *fasthttp.RequestCtx) { vkID := ctx.UserValue("vk_id").(string) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + for _, vk := range data.VirtualKeys { + if vk.ID == vkID { + SendJSON(ctx, map[string]interface{}{ + "virtual_key": vk, + }) + return + } + } + SendError(ctx, 404, "Virtual key not found") + return + } vk, err := h.configStore.GetVirtualKey(ctx, vkID) if err != nil { if errors.Is(err, configstore.ErrNotFound) { @@ -531,6 +572,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { rateLimit.RequestResetDuration = req.RateLimit.RequestResetDuration } + if err := validateRateLimit(&rateLimit); err != nil { + return err + } if err := h.configStore.UpdateRateLimit(ctx, &rateLimit, tx); err != nil { return err } @@ -588,29 +632,29 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return fmt.Errorf("both max_limit and reset_duration are required when creating a new provider budget") } } - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) - } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) - } - } + // Get keys for this provider config if specified + var keys []configstoreTables.TableKey + if len(pc.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } + } - // Create new provider config - providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ - VirtualKeyID: vk.ID, - Provider: pc.Provider, - Weight: pc.Weight, - AllowedModels: pc.AllowedModels, - Keys: keys, - } - // Create budget for provider config if provided - if pc.Budget != nil { + // Create new provider config + providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: pc.Weight, + AllowedModels: pc.AllowedModels, + Keys: keys, + } + // Create budget for provider config if provided + if pc.Budget != nil { budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: *pc.Budget.MaxLimit, @@ -648,32 +692,33 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if err := h.configStore.CreateVirtualKeyProviderConfig(ctx, providerConfig, tx); err != nil { return err } - } else { - // Update existing provider config - existing, ok := existingConfigsMap[*pc.ID] - if !ok { - return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) - } - requestConfigsMap[*pc.ID] = true - existing.Provider = pc.Provider - existing.Weight = pc.Weight - existing.AllowedModels = pc.AllowedModels - - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } else { + // Update existing provider config + existing, ok := existingConfigsMap[*pc.ID] + if !ok { + return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + requestConfigsMap[*pc.ID] = true + existing.Provider = pc.Provider + existing.Weight = pc.Weight + existing.AllowedModels = pc.AllowedModels + + // Update keys only if KeyIDs field was present in the request. + // - If pc.KeyIDs is nil (field omitted), leave existing.Keys unchanged. + // - If pc.KeyIDs is an empty slice (field present but empty), clear keys. + // - If pc.KeyIDs has values, update keys accordingly. + if pc.KeyIDs != nil { + keys, err := h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } + existing.Keys = keys } - } - existing.Keys = keys - // Handle budget updates for provider config + // Handle budget updates for provider config if pc.Budget != nil { if existing.BudgetID != nil { // Update existing budget @@ -740,6 +785,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if pc.RateLimit.RequestResetDuration != nil { rateLimit.RequestResetDuration = pc.RateLimit.RequestResetDuration } + if err := validateRateLimit(&rateLimit); err != nil { + return err + } if err := h.configStore.UpdateRateLimit(ctx, &rateLimit, tx); err != nil { return err } @@ -885,6 +933,7 @@ func (h *GovernanceHandler) deleteVirtualKey(ctx *fasthttp.RequestCtx) { SendError(ctx, 404, "Virtual key not found") return } + logger.Error("failed to delete virtual key: %v", err) SendError(ctx, 500, "Failed to delete virtual key") return } @@ -898,6 +947,33 @@ func (h *GovernanceHandler) deleteVirtualKey(ctx *fasthttp.RequestCtx) { // getTeams handles GET /api/governance/teams - Get all teams func (h *GovernanceHandler) getTeams(ctx *fasthttp.RequestCtx) { customerID := string(ctx.QueryArgs().Peek("customer_id")) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + if customerID != "" { + teams := make(map[string]*configstoreTables.TableTeam) + for _, team := range data.Teams { + if team.CustomerID != nil && *team.CustomerID == customerID { + teams[team.ID] = team + } + } + SendJSON(ctx, map[string]interface{}{ + "teams": teams, + "count": len(teams), + }) + } else { + SendJSON(ctx, map[string]interface{}{ + "teams": data.Teams, + "count": len(data.Teams), + }) + } + return + } // Preload relationships for complete information teams, err := h.configStore.GetTeams(ctx, customerID) if err != nil { @@ -980,6 +1056,24 @@ func (h *GovernanceHandler) createTeam(ctx *fasthttp.RequestCtx) { // getTeam handles GET /api/governance/teams/{team_id} - Get a specific team func (h *GovernanceHandler) getTeam(ctx *fasthttp.RequestCtx) { teamID := ctx.UserValue("team_id").(string) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + team, ok := data.Teams[teamID] + if !ok { + SendError(ctx, 404, "Team not found") + return + } + SendJSON(ctx, map[string]interface{}{ + "team": team, + }) + return + } team, err := h.configStore.GetTeam(ctx, teamID) if err != nil { if errors.Is(err, configstore.ErrNotFound) { @@ -1112,6 +1206,20 @@ func (h *GovernanceHandler) deleteTeam(ctx *fasthttp.RequestCtx) { // getCustomers handles GET /api/governance/customers - Get all customers func (h *GovernanceHandler) getCustomers(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "customers": data.Customers, + "count": len(data.Customers), + }) + return + } customers, err := h.configStore.GetCustomers(ctx) if err != nil { logger.Error("failed to retrieve customers: %v", err) @@ -1190,6 +1298,24 @@ func (h *GovernanceHandler) createCustomer(ctx *fasthttp.RequestCtx) { // getCustomer handles GET /api/governance/customers/{customer_id} - Get a specific customer func (h *GovernanceHandler) getCustomer(ctx *fasthttp.RequestCtx) { customerID := ctx.UserValue("customer_id").(string) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + customer, ok := data.Customers[customerID] + if !ok { + SendError(ctx, 404, "Customer not found") + return + } + SendJSON(ctx, map[string]interface{}{ + "customer": customer, + }) + return + } customer, err := h.configStore.GetCustomer(ctx, customerID) if err != nil { if errors.Is(err, configstore.ErrNotFound) { @@ -1316,6 +1442,64 @@ func (h *GovernanceHandler) deleteCustomer(ctx *fasthttp.RequestCtx) { }) } +// Budget and Rate Limit GET operations + +// getBudgets handles GET /api/governance/budgets - Get all budgets +func (h *GovernanceHandler) getBudgets(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "budgets": data.Budgets, + "count": len(data.Budgets), + }) + return + } + budgets, err := h.configStore.GetBudgets(ctx) + if err != nil { + logger.Error("failed to retrieve budgets: %v", err) + SendError(ctx, 500, "failed to retrieve budgets") + return + } + SendJSON(ctx, map[string]interface{}{ + "budgets": budgets, + "count": len(budgets), + }) +} + +// getRateLimits handles GET /api/governance/rate-limits - Get all rate limits +func (h *GovernanceHandler) getRateLimits(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "rate_limits": data.RateLimits, + "count": len(data.RateLimits), + }) + return + } + rateLimits, err := h.configStore.GetRateLimits(ctx) + if err != nil { + logger.Error("failed to retrieve rate limits: %v", err) + SendError(ctx, 500, "failed to retrieve rate limits") + return + } + SendJSON(ctx, map[string]interface{}{ + "rate_limits": rateLimits, + "count": len(rateLimits), + }) +} + // validateRateLimit validates the rate limit func validateRateLimit(rateLimit *configstoreTables.TableRateLimit) error { if rateLimit.TokenMaxLimit != nil && (*rateLimit.TokenMaxLimit < 0 || *rateLimit.TokenMaxLimit == 0) { diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 8900206bf..52ccabc98 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -1185,7 +1185,6 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge ctx.SetContentType("text/event-stream") ctx.Response.Header.Set("Cache-Control", "no-cache") ctx.Response.Header.Set("Connection", "keep-alive") - ctx.Response.Header.Set("Access-Control-Allow-Origin", "*") // Get the streaming channel stream, bifrostErr := getStream() diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 88b710d20..56a1c94a2 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -8,6 +8,7 @@ import ( "fmt" "slices" "sort" + "strings" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" @@ -51,6 +52,21 @@ func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...lib.Bifrost // executeTool handles POST /v1/mcp/tool/execute - Execute MCP tool func (h *MCPHandler) executeTool(ctx *fasthttp.RequestCtx) { + // Check format query parameter + format := strings.ToLower(string(ctx.QueryArgs().Peek("format"))) + switch format { + case "chat", "": + h.executeChatMCPTool(ctx) + case "responses": + h.executeResponsesMCPTool(ctx) + default: + SendError(ctx, fasthttp.StatusBadRequest, "Invalid format value, must be 'chat' or 'responses'") + return + } +} + +// executeChatMCPTool handles POST /v1/mcp/tool/execute?format=chat - Execute MCP tool +func (h *MCPHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) { var req schemas.ChatAssistantMessageToolCall if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) @@ -72,14 +88,47 @@ func (h *MCPHandler) executeTool(ctx *fasthttp.RequestCtx) { } // Execute MCP tool - resp, bifrostErr := h.client.ExecuteMCPTool(*bifrostCtx, req) + toolMessage, bifrostErr := h.client.ExecuteChatMCPTool(*bifrostCtx, req) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return } // Send successful response - SendJSON(ctx, resp) + SendJSON(ctx, toolMessage) +} + +// executeResponsesMCPTool handles POST /v1/mcp/tool/execute?format=responses - Execute MCP tool +func (h *MCPHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) { + var req schemas.ResponsesToolMessage + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) + return + } + + // Validate required fields + if req.Name == nil || *req.Name == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required") + return + } + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false) + defer cancel() // Ensure cleanup on function exit + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") + return + } + + // Execute MCP tool + toolMessage, bifrostErr := h.client.ExecuteResponsesMCPTool(*bifrostCtx, &req) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + // Send successful response + SendJSON(ctx, toolMessage) } // getMCPClients handles GET /api/mcp/clients - Get all MCP clients @@ -119,7 +168,7 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { clients = append(clients, schemas.MCPClient{ Config: h.store.RedactMCPClientConfig(connectedClient.Config), Tools: sortedTools, - State: connectedClient.State, + State: connectedClient.State, // Use the state from MCPClientState }) } else { // Client is in config but not connected, mark as errored @@ -189,13 +238,28 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_execute: %v", err)) return } + + // Auto-clear tools_to_auto_execute if tools_to_execute is empty + // If no tools are allowed to execute, no tools can be auto-executed + if len(req.ToolsToExecute) == 0 { + req.ToolsToAutoExecute = []string{} + } + + if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) + return + } + if err := validateMCPClientName(req.Name); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) + return + } if err := h.mcpManager.AddMCPClient(ctx, req); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to add MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to connect MCP client: %v", err)) return } SendJSON(ctx, map[string]any{ "status": "success", - "message": "MCP client added successfully", + "message": "MCP client connected successfully", }) } @@ -219,6 +283,24 @@ func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { return } + // Auto-clear tools_to_auto_execute if tools_to_execute is empty + // If no tools are allowed to execute, no tools can be auto-executed + if len(req.ToolsToExecute) == 0 { + req.ToolsToAutoExecute = []string{} + } + + // Validate tools_to_auto_execute + if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) + return + } + + // Validate client name + if err := validateMCPClientName(req.Name); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) + return + } + if err := h.mcpManager.EditMCPClient(ctx, id, req); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to edit MCP client: %v", err)) return @@ -280,3 +362,69 @@ func validateToolsToExecute(toolsToExecute []string) error { return nil } + +func validateToolsToAutoExecute(toolsToAutoExecute []string, toolsToExecute []string) error { + if len(toolsToAutoExecute) > 0 { + // Check if wildcard "*" is combined with other tool names + hasWildcard := slices.Contains(toolsToAutoExecute, "*") + if hasWildcard && len(toolsToAutoExecute) > 1 { + return fmt.Errorf("wildcard '*' cannot be combined with other tool names") + } + + // Check for duplicate entries + seen := make(map[string]bool) + for _, tool := range toolsToAutoExecute { + if seen[tool] { + return fmt.Errorf("duplicate tool name '%s'", tool) + } + seen[tool] = true + } + + // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute + // Create a set of allowed tools from ToolsToExecute + allowedTools := make(map[string]bool) + hasWildcardInExecute := slices.Contains(toolsToExecute, "*") + if hasWildcardInExecute { + // If "*" is in ToolsToExecute, all tools are allowed + return nil + } + for _, tool := range toolsToExecute { + allowedTools[tool] = true + } + + // Validate each tool in ToolsToAutoExecute + for _, tool := range toolsToAutoExecute { + if tool == "*" { + // Wildcard is allowed if "*" is in ToolsToExecute + if !hasWildcardInExecute { + return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) + } + } else if !allowedTools[tool] { + return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) + } + } + } + + return nil +} + +func validateMCPClientName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("client name is required") + } + for _, r := range name { + if r > 127 { // non-ASCII + return fmt.Errorf("name must contain only ASCII characters") + } + } + if strings.Contains(name, "-") { + return fmt.Errorf("client name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("client name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("client name cannot start with a number") + } + return nil +} diff --git a/transports/bifrost-http/handlers/mcp_server.go b/transports/bifrost-http/handlers/mcp_server.go new file mode 100644 index 000000000..1bdf5b6ca --- /dev/null +++ b/transports/bifrost-http/handlers/mcp_server.go @@ -0,0 +1,392 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming. +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "slices" + "strings" + "sync" + + "github.com/fasthttp/router" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// MCPToolExecutor interface defines the method needed for executing MCP tools +type MCPToolManager interface { + GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool + ExecuteChatMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) + ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) +} + +// MCPServerHandler manages HTTP requests for MCP server operations +// It implements the MCP protocol over HTTP streaming (SSE) for MCP clients +type MCPServerHandler struct { + toolManager MCPToolManager + globalMCPServer *server.MCPServer + vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server + config *lib.Config + mu sync.RWMutex +} + +// NewMCPServerHandler creates a new MCP server handler instance +func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + if toolManager == nil { + return nil, fmt.Errorf("tool manager is required") + } + + // Create MCP server instance using mcp-go + globalMCPServer := server.NewMCPServer( + "global", + version, + server.WithToolCapabilities(true), + ) + + handler := &MCPServerHandler{ + toolManager: toolManager, + globalMCPServer: globalMCPServer, + config: config, + vkMCPServers: make(map[string]*server.MCPServer), + } + + if err := handler.SyncAllMCPServers(ctx); err != nil { + return nil, fmt.Errorf("failed to sync all MCP servers: %w", err) + } + + return handler, nil +} + +// RegisterRoutes registers the MCP server route +func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { + // MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE) + r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...)) + r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...)) +} + +// handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages +func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { + mcpServer, err := h.getMCPServerForRequest(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) + return + } + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false) + defer cancel() + + // Use mcp-go server to handle the request + // HandleMessage processes JSON-RPC messages and returns appropriate responses + response := mcpServer.HandleMessage(*bifrostCtx, ctx.PostBody()) + + // Check if response is nil (notification - no response needed) + if response == nil { + ctx.SetStatusCode(fasthttp.StatusOK) + return + } + + // Marshal and send response + responseJSON, err := json.Marshal(response) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to marshal MCP response: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err)) + return + } + + ctx.SetContentType("application/json") + ctx.SetBody(responseJSON) +} + +// handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming +func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { + _, err := h.getMCPServerForRequest(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) + return + } + + // Set SSE headers + ctx.SetContentType("text/event-stream") + ctx.Response.Header.Set("Cache-Control", "no-cache") + ctx.Response.Header.Set("Connection", "keep-alive") + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false) + + // Use streaming response writer + ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { + defer func() { + cancel() + _ = w.Flush() + }() + + // Send initial connection message + initMessage := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "connection/opened", + } + if initJSON, err := json.Marshal(initMessage); err == nil { + fmt.Fprintf(w, "data: %s\n\n", initJSON) + w.Flush() + } + + // Wait for context cancellation (client disconnect or server-side cancel) + <-(*bifrostCtx).Done() + }) +} + +// Sync methods for MCP servers + +func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error { + h.mu.Lock() + defer h.mu.Unlock() + availableTools := h.toolManager.GetAvailableMCPTools(ctx) + h.syncServer(h.globalMCPServer, availableTools) + logger.Debug("Synced global MCP server with %d tools", len(availableTools)) + + // initialize vkMCPServers map + if h.config.ConfigStore != nil { + virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx) + if err != nil { + return fmt.Errorf("failed to get virtual keys: %w", err) + } + h.vkMCPServers = make(map[string]*server.MCPServer) + for i := range virtualKeys { + vk := &virtualKeys[i] + h.vkMCPServers[vk.Value] = server.NewMCPServer( + vk.Name, + version, + server.WithToolCapabilities(true), + ) + availableTools := h.fetchToolsForVK(vk) + h.syncServer(h.vkMCPServers[vk.Value], availableTools) + logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools)) + } + } + return nil +} + +func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) { + h.mu.Lock() + defer h.mu.Unlock() + vkServer, ok := h.vkMCPServers[vk.Value] + if !ok { + // Add new server + vkServer = server.NewMCPServer( + vk.Name, + version, + server.WithToolCapabilities(true), + ) + h.vkMCPServers[vk.Value] = vkServer + } + availableTools := h.fetchToolsForVK(vk) + h.syncServer(vkServer, availableTools) + h.vkMCPServers[vk.Value] = vkServer + logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools)) +} + +func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.vkMCPServers, vkValue) +} + +func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool) { + // Clear existing tools + toolMap := server.ListTools() + for toolName, _ := range toolMap { + server.DeleteTools(toolName) + } + + // Register tools from all connected clients + for _, tool := range availableTools { + // Only process function tools (skip custom tools) + if tool.Function == nil { + continue + } + + // Capture tool name for closure + toolName := tool.Function.Name + + handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Convert to Bifrost tool call format + toolCallType := "function" + toolCallID := fmt.Sprintf("mcp-%s", toolName) + argsJSON, jsonErr := json.Marshal(request.GetArguments()) + if jsonErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil + } + toolCall := schemas.ChatAssistantMessageToolCall{ + ID: &toolCallID, + Type: &toolCallType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolName, + Arguments: string(argsJSON), + }, + } + + // Execute the tool via tool executor + toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, toolCall) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil + } + + // Extract content from tool message + var resultText string + if toolMessage != nil && toolMessage.Content != nil { + // Handle ContentStr (string content) + if toolMessage.Content.ContentStr != nil { + resultText = *toolMessage.Content.ContentStr + } else if toolMessage.Content.ContentBlocks != nil { + // Handle ContentBlocks (structured content) + for _, block := range toolMessage.Content.ContentBlocks { + if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil { + resultText += *block.Text + } + } + } + } + + // Return result using mcp-go helper + return mcp.NewToolResultText(resultText), nil + } + + // Convert description from *string to string + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Convert Parameters to mcp.ToolInputSchema + var inputSchema mcp.ToolInputSchema + if tool.Function.Parameters != nil { + inputSchema.Type = tool.Function.Parameters.Type + if tool.Function.Parameters.Properties != nil { + // Convert *map[string]interface{} to map[string]any + props := make(map[string]any) + for k, v := range *tool.Function.Parameters.Properties { + props[k] = v + } + inputSchema.Properties = props + } + if tool.Function.Parameters.Required != nil { + inputSchema.Required = tool.Function.Parameters.Required + } + } else { + // Default to empty object schema if no parameters + inputSchema.Type = "object" + inputSchema.Properties = make(map[string]any) + } + + // Register tool with the server + server.AddTool(mcp.Tool{ + Name: toolName, + Description: description, + InputSchema: inputSchema, + }, handler) + } +} + +// fetchToolsForVK fetches the tools for a given virtual key value. +// vkValue is the virtual key value for the server, if empty, all tools will be fetched for global mcp server. +// Returns a map of tool name to tool. +func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) []schemas.ChatTool { + ctx := context.Background() + + if len(vk.MCPConfigs) > 0 { + executeOnlyTools := make([]string, 0) + for _, vkMcpConfig := range vk.MCPConfigs { + if len(vkMcpConfig.ToolsToExecute) == 0 { + // No tools specified in virtual key config - skip this client entirely + continue + } + + // Handle wildcard in virtual key config - allow all tools from this client + if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { + // Virtual key uses wildcard - use client-specific wildcard + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s/*", vkMcpConfig.MCPClient.Name)) + continue + } + + for _, tool := range vkMcpConfig.ToolsToExecute { + if tool != "" { + // Add the tool - client config filtering will be handled by mcp.go + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s/%s", vkMcpConfig.MCPClient.Name, tool)) + } + } + } + + // Set even when empty to exclude tools when no tools are present in the virtual key config + ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), executeOnlyTools) + } + + return h.toolManager.GetAvailableMCPTools(ctx) +} + +// Utility methods + +func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + h.config.Mu.RLock() + enforceVK := h.config.ClientConfig.EnforceGovernanceHeader + h.config.Mu.RUnlock() + + vk := getVKFromRequest(ctx) + + // Return global MCP server if not enforcing virtual key header and no virtual key is provided + if !enforceVK && vk == "" { + return h.globalMCPServer, nil + } + + // Check if virtual key is provided + if vk == "" { + return nil, fmt.Errorf("virtual key header is required to access MCP server.") + } + + // Check if vk exists in the map + vkServer, ok := h.vkMCPServers[vk] + if !ok { + return nil, fmt.Errorf("virtual key not found.") + } + + return vkServer, nil +} + +func getVKFromRequest(ctx *fasthttp.RequestCtx) string { + if value := strings.TrimSpace(string(ctx.Request.Header.Peek(string(schemas.BifrostContextKeyVirtualKey)))); value != "" { + return value + } + + authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization"))) + if authHeader != "" { + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + token := strings.TrimSpace(authHeader[7:]) + if token != "" && strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) { + return token + } + } + } + + if apiKey := strings.TrimSpace(string(ctx.Request.Header.Peek("x-api-key"))); apiKey != "" { + if strings.HasPrefix(strings.ToLower(apiKey), governance.VirtualKeyPrefix) { + return apiKey + } + } + + return "" +} diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index e981ef2c9..544d84fb2 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -12,7 +12,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/encrypt" - "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -47,7 +46,7 @@ func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { } // TransportInterceptorMiddleware collects all plugin interceptors and calls them one by one -func TransportInterceptorMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { +func TransportInterceptorMiddleware(config *lib.Config, enterpriseOverrides lib.EnterpriseOverrides) lib.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { // Get plugins from config - lock-free read @@ -56,10 +55,14 @@ func TransportInterceptorMiddleware(config *lib.Config) lib.BifrostHTTPMiddlewar next(ctx) return } + if enterpriseOverrides == nil { + next(ctx) + return + } // If governance plugin is not loaded, skip interception hasGovernance := false for _, p := range plugins { - if p.GetName() == governance.PluginName { + if p.GetName() == enterpriseOverrides.GetGovernancePluginName() { hasGovernance = true break } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index b5a0fff8b..719f60c84 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -44,6 +44,11 @@ const ( MaxRetryBackoff = 1000000 * time.Millisecond // Maximum retry backoff: 1000000ms (1000 seconds) ) +const ( + DBLookupMaxRetries = 5 + DBLookupDelay = 1 * time.Second +) + // ConfigData represents the configuration data for the Bifrost HTTP transport. // It contains the client configuration, provider configurations, MCP configuration, // vector store configuration, config store configuration, and logs store configuration. @@ -225,6 +230,9 @@ var DefaultClientConfig = configstore.ClientConfig{ AllowDirectKeys: false, AllowedOrigins: []string{"*"}, MaxRequestBodySizeMB: 100, + MCPAgentDepth: 10, + MCPToolExecutionTimeout: 30, + MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer), EnableLiteLLMFallbacks: false, } @@ -265,7 +273,7 @@ func (c *Config) initializeEncryption(configKey string) error { // - Case conversion for provider names (e.g., "OpenAI" -> "openai") // - In-memory storage for ultra-fast access during request processing // - Graceful handling of missing config files -func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { +func LoadConfig(ctx context.Context, configDirPath string, EnterpriseOverrides EnterpriseOverrides) (*Config, error) { // Initialize separate database connections for optimal performance at scale configFilePath := filepath.Join(configDirPath, "config.json") configDBPath := filepath.Join(configDirPath, "config.db") @@ -288,7 +296,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { // If config file doesn't exist, we will directly use the config store (create one if it doesn't exist) if os.IsNotExist(err) { logger.Info("config file not found at path: %s, initializing with default values", absConfigFilePath) - return loadConfigFromDefaults(ctx, config, configDBPath, logsDBPath) + return loadConfigFromDefaults(ctx, config, configDBPath, logsDBPath, EnterpriseOverrides) } return nil, fmt.Errorf("failed to read config file: %w", err) } @@ -344,12 +352,12 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { } // If config file exists, we will use it to bootstrap config tables logger.Info("loading configuration from: %s", absConfigFilePath) - return loadConfigFromFile(ctx, config, data) + return loadConfigFromFile(ctx, config, data, EnterpriseOverrides) } // loadConfigFromFile initializes configuration from a JSON config file. // It merges config file data with existing database config, with store taking priority. -func loadConfigFromFile(ctx context.Context, config *Config, data []byte) (*Config, error) { +func loadConfigFromFile(ctx context.Context, config *Config, data []byte, EnterpriseOverrides EnterpriseOverrides) (*Config, error) { var configData ConfigData if err := json.Unmarshal(data, &configData); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) @@ -390,7 +398,7 @@ func loadConfigFromFile(ctx context.Context, config *Config, data []byte) (*Conf loadEnvKeysFromFile(ctx, config) // Initialize framework config and pricing manager - initFrameworkConfigFromFile(ctx, config, &configData) + initFrameworkConfigFromFile(ctx, config, &configData, EnterpriseOverrides) // Initialize encryption if err = initEncryptionFromFile(config, &configData); err != nil { @@ -460,7 +468,51 @@ func loadClientConfigFromFile(ctx context.Context, config *Config, configData *C // Merge with config file if present if configData.Client != nil { - mergeClientConfig(&config.ClientConfig, configData.Client) + logger.Debug("merging client config from config file with store") + // DB takes priority, but fill in empty/zero values from config file + if config.ClientConfig.InitialPoolSize == 0 && configData.Client.InitialPoolSize != 0 { + config.ClientConfig.InitialPoolSize = configData.Client.InitialPoolSize + } + if len(config.ClientConfig.PrometheusLabels) == 0 && len(configData.Client.PrometheusLabels) > 0 { + config.ClientConfig.PrometheusLabels = configData.Client.PrometheusLabels + } + if len(config.ClientConfig.AllowedOrigins) == 0 && len(configData.Client.AllowedOrigins) > 0 { + config.ClientConfig.AllowedOrigins = configData.Client.AllowedOrigins + } + if config.ClientConfig.MaxRequestBodySizeMB == 0 && configData.Client.MaxRequestBodySizeMB != 0 { + config.ClientConfig.MaxRequestBodySizeMB = configData.Client.MaxRequestBodySizeMB + } + // Boolean fields: only override if DB has false and config file has true + if !config.ClientConfig.DropExcessRequests && configData.Client.DropExcessRequests { + config.ClientConfig.DropExcessRequests = configData.Client.DropExcessRequests + } + if !config.ClientConfig.EnableLogging && configData.Client.EnableLogging { + config.ClientConfig.EnableLogging = configData.Client.EnableLogging + } + if !config.ClientConfig.DisableContentLogging && configData.Client.DisableContentLogging { + config.ClientConfig.DisableContentLogging = configData.Client.DisableContentLogging + } + if !config.ClientConfig.EnableGovernance && configData.Client.EnableGovernance { + config.ClientConfig.EnableGovernance = configData.Client.EnableGovernance + } + if !config.ClientConfig.EnforceGovernanceHeader && configData.Client.EnforceGovernanceHeader { + config.ClientConfig.EnforceGovernanceHeader = configData.Client.EnforceGovernanceHeader + } + if !config.ClientConfig.AllowDirectKeys && configData.Client.AllowDirectKeys { + config.ClientConfig.AllowDirectKeys = configData.Client.AllowDirectKeys + } + if !config.ClientConfig.EnableLiteLLMFallbacks && configData.Client.EnableLiteLLMFallbacks { + config.ClientConfig.EnableLiteLLMFallbacks = configData.Client.EnableLiteLLMFallbacks + } + if config.ClientConfig.MCPAgentDepth == 0 && configData.Client.MCPAgentDepth != 0 { + config.ClientConfig.MCPAgentDepth = configData.Client.MCPAgentDepth + } + if config.ClientConfig.MCPToolExecutionTimeout == 0 && configData.Client.MCPToolExecutionTimeout != 0 { + config.ClientConfig.MCPToolExecutionTimeout = configData.Client.MCPToolExecutionTimeout + } + if config.ClientConfig.MCPCodeModeBindingLevel == "" && configData.Client.MCPCodeModeBindingLevel != "" { + config.ClientConfig.MCPCodeModeBindingLevel = configData.Client.MCPCodeModeBindingLevel + } // Update store with merged config if config.ConfigStore != nil { logger.Debug("updating merged client config in store") @@ -803,7 +855,6 @@ func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *Conf if mcpConfig != nil { config.MCPConfig = mcpConfig - // Merge with config file if present if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 { mergeMCPConfig(ctx, config, configData, mcpConfig) @@ -1455,7 +1506,7 @@ func loadEnvKeysFromFile(ctx context.Context, config *Config) { } // initFrameworkConfigFromFile initializes framework config and pricing manager from file -func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData *ConfigData) { +func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData *ConfigData, EnterpriseOverrides EnterpriseOverrides) { pricingConfig := &modelcatalog.Config{} if config.ConfigStore != nil { frameworkConfig, err := config.ConfigStore.GetFrameworkConfig(ctx) @@ -1479,9 +1530,21 @@ func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData Pricing: pricingConfig, } - pricingManager, err := modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, logger) - if err != nil { - logger.Warn("failed to initialize pricing manager: %v", err) + var pricingManager *modelcatalog.ModelCatalog + var err error + + // Check if EnterpriseOverrides is provided, otherwise use default initialization + if EnterpriseOverrides != nil { + pricingManager, err = EnterpriseOverrides.LoadPricingManager(ctx, pricingConfig, config.ConfigStore) + if err != nil { + logger.Warn("failed to load pricing manager: %v", err) + } + } else { + // Use default modelcatalog initialization when no enterprise overrides are provided + pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) + if err != nil { + logger.Warn("failed to initialize pricing manager: %v", err) + } } config.PricingManager = pricingManager } @@ -1514,7 +1577,7 @@ func initEncryptionFromFile(config *Config, configData *ConfigData) error { // loadConfigFromDefaults initializes configuration when no config file exists. // It creates a default SQLite config store and loads/creates default configurations. -func loadConfigFromDefaults(ctx context.Context, config *Config, configDBPath, logsDBPath string) (*Config, error) { +func loadConfigFromDefaults(ctx context.Context, config *Config, configDBPath, logsDBPath string, EnterpriseOverrides EnterpriseOverrides) (*Config, error) { var err error // Initialize default config store @@ -1556,7 +1619,7 @@ func loadConfigFromDefaults(ctx context.Context, config *Config, configDBPath, l } // Initialize framework config and pricing manager - if err = initDefaultFrameworkConfig(ctx, config); err != nil { + if err = initDefaultFrameworkConfig(ctx, config, EnterpriseOverrides); err != nil { return nil, err } @@ -1804,7 +1867,7 @@ func loadDefaultEnvKeys(ctx context.Context, config *Config) error { } // initDefaultFrameworkConfig initializes framework configuration and pricing manager -func initDefaultFrameworkConfig(ctx context.Context, config *Config) error { +func initDefaultFrameworkConfig(ctx context.Context, config *Config, EnterpriseOverrides EnterpriseOverrides) error { frameworkConfig, err := config.ConfigStore.GetFrameworkConfig(ctx) if err != nil { logger.Warn("failed to get framework config from store: %v", err) @@ -1849,9 +1912,20 @@ func initDefaultFrameworkConfig(ctx context.Context, config *Config) error { } // Initialize pricing manager - pricingManager, err := modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, logger) - if err != nil { - logger.Warn("failed to initialize pricing manager: %v", err) + var pricingManager *modelcatalog.ModelCatalog + + // Check if EnterpriseOverrides is provided, otherwise use default initialization + if EnterpriseOverrides != nil { + pricingManager, err = EnterpriseOverrides.LoadPricingManager(ctx, pricingConfig, config.ConfigStore) + if err != nil { + logger.Warn("failed to initialize pricing manager: %v", err) + } + } else { + // Use default modelcatalog initialization when no enterprise overrides are provided + pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) + if err != nil { + logger.Warn("failed to initialize pricing manager: %v", err) + } } config.PricingManager = pricingManager return nil @@ -2737,7 +2811,7 @@ func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClien if err := c.client.AddMCPClient(c.MCPConfig.ClientConfigs[len(c.MCPConfig.ClientConfigs)-1]); err != nil { c.MCPConfig.ClientConfigs = c.MCPConfig.ClientConfigs[:len(c.MCPConfig.ClientConfigs)-1] c.cleanupEnvKeys("", clientConfig.ID, newEnvKeys) - return fmt.Errorf("failed to add MCP client: %w", err) + return fmt.Errorf("failed to connect MCP client: %w", err) } if c.ConfigStore != nil { @@ -2886,8 +2960,10 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig sch // Update the in-memory config with the processed values c.MCPConfig.ClientConfigs[configIndex].Name = processedConfig.Name + c.MCPConfig.ClientConfigs[configIndex].IsCodeModeClient = processedConfig.IsCodeModeClient c.MCPConfig.ClientConfigs[configIndex].Headers = processedConfig.Headers c.MCPConfig.ClientConfigs[configIndex].ToolsToExecute = processedConfig.ToolsToExecute + c.MCPConfig.ClientConfigs[configIndex].ToolsToAutoExecute = processedConfig.ToolsToAutoExecute // Check if client is registered in Bifrost (can be not registered if client initialization failed) if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { @@ -2922,12 +2998,14 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig sch func (c *Config) RedactMCPClientConfig(config schemas.MCPClientConfig) schemas.MCPClientConfig { // Create a copy with basic fields configCopy := schemas.MCPClientConfig{ - ID: config.ID, - Name: config.Name, - ConnectionType: config.ConnectionType, - ConnectionString: config.ConnectionString, - StdioConfig: config.StdioConfig, - ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ID: config.ID, + Name: config.Name, + IsCodeModeClient: config.IsCodeModeClient, + ConnectionType: config.ConnectionType, + ConnectionString: config.ConnectionString, + StdioConfig: config.StdioConfig, + ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ToolsToAutoExecute: append([]string{}, config.ToolsToAutoExecute...), } // Handle connection string if present diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 709b36474..36e7476ed 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -133,6 +133,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/google/uuid" "github.com/maximhq/bifrost/core/schemas" @@ -190,6 +191,10 @@ func (m *MockConfigStore) RunMigration(ctx context.Context, migration *migrator. return nil } +func (m *MockConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) { + return fn(ctx) +} + // Client config func (m *MockConfigStore) UpdateClientConfig(ctx context.Context, config *configstore.ClientConfig) error { m.clientConfig = config @@ -309,6 +314,10 @@ func (m *MockConfigStore) GetRateLimit(ctx context.Context, id string) (*tables. return nil, nil } +func (m *MockConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { + return []tables.TableRateLimit{}, nil +} + func (m *MockConfigStore) CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error { if m.governanceConfig == nil { m.governanceConfig = &configstore.GovernanceConfig{} @@ -615,12 +624,12 @@ func makeMCPClientConfig(id, name string) schemas.MCPClientConfig { // testLogger is a minimal logger implementation for testing type testLogger struct{} -func (l *testLogger) Debug(msg string, args ...any) {} -func (l *testLogger) Info(msg string, args ...any) {} -func (l *testLogger) Warn(msg string, args ...any) {} -func (l *testLogger) Error(msg string, args ...any) {} -func (l *testLogger) Fatal(msg string, args ...any) {} -func (l *testLogger) SetLevel(level schemas.LogLevel) {} +func (l *testLogger) Debug(msg string, args ...any) {} +func (l *testLogger) Info(msg string, args ...any) {} +func (l *testLogger) Warn(msg string, args ...any) {} +func (l *testLogger) Error(msg string, args ...any) {} +func (l *testLogger) Fatal(msg string, args ...any) {} +func (l *testLogger) SetLevel(level schemas.LogLevel) {} func (l *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {} // initTestLogger initializes the global logger for SQLite integration tests @@ -1827,7 +1836,7 @@ func TestProviderHashComparison_DifferentHash(t *testing.T) { Weight: dbKey.Weight, }) fileKeyHash, _ := configstore.GenerateKeyHash(fileKey) - if dbKeyHash == fileKeyHash || fileKey.Name == dbKey.Name { + if dbKeyHash == fileKeyHash || fileKey.Name == dbKey.Name { found = true break } @@ -2188,13 +2197,13 @@ func TestProviderHashComparison_OptionalFieldsPresence(t *testing.T) { // All hashes should be unique hashes := map[string]string{ - "no_optional": hashNoOptional, - "with_network": hashWithNetwork, - "with_proxy": hashWithProxy, - "with_conc": hashWithConcurrency, - "with_custom": hashWithCustom, - "with_raw": hashWithRawResponse, - "all_fields": hashAllFields, + "no_optional": hashNoOptional, + "with_network": hashWithNetwork, + "with_proxy": hashWithProxy, + "with_conc": hashWithConcurrency, + "with_custom": hashWithCustom, + "with_raw": hashWithRawResponse, + "all_fields": hashAllFields, } seen := make(map[string]string) @@ -3000,9 +3009,9 @@ func TestProviderHashComparison_ProviderChangedKeysUnchanged(t *testing.T) { sameKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-original-123", // SAME + Value: "sk-original-123", // SAME Models: []string{"gpt-4", "gpt-3.5-turbo"}, // SAME - Weight: 1.5, // SAME + Weight: 1.5, // SAME } sameKeyHash, _ := configstore.GenerateKeyHash(sameKey) @@ -3038,10 +3047,10 @@ func TestProviderHashComparison_ProviderChangedKeysUnchanged(t *testing.T) { // - Keep existing keys from DB (they weren't changed in file) updatedConfig := configstore.ProviderConfig{ - Keys: dbConfig.Keys, // Keep original keys from DB - NetworkConfig: fileConfig.NetworkConfig, // Update from file - SendBackRawResponse: fileConfig.SendBackRawResponse, // Update from file - ConfigHash: fileProviderHash, // New provider hash + Keys: dbConfig.Keys, // Keep original keys from DB + NetworkConfig: fileConfig.NetworkConfig, // Update from file + SendBackRawResponse: fileConfig.SendBackRawResponse, // Update from file + ConfigHash: fileProviderHash, // New provider hash } // Verify keys are preserved (same values as DB) @@ -3099,9 +3108,9 @@ func TestProviderHashComparison_KeysChangedProviderUnchanged(t *testing.T) { changedKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-new-456", // CHANGED! - Models: []string{"gpt-4", "gpt-3.5-turbo", "o1"}, // CHANGED! - Weight: 2.0, // CHANGED! + Value: "sk-new-456", // CHANGED! + Models: []string{"gpt-4", "gpt-3.5-turbo", "o1"}, // CHANGED! + Weight: 2.0, // CHANGED! } changedKeyHash, _ := configstore.GenerateKeyHash(changedKey) @@ -3137,10 +3146,10 @@ func TestProviderHashComparison_KeysChangedProviderUnchanged(t *testing.T) { // - Update keys from file (they were changed) updatedConfig := configstore.ProviderConfig{ - Keys: fileConfig.Keys, // Update keys from file - NetworkConfig: dbConfig.NetworkConfig, // Keep from DB - SendBackRawResponse: dbConfig.SendBackRawResponse, // Keep from DB - ConfigHash: dbProviderHash, // Provider hash unchanged + Keys: fileConfig.Keys, // Update keys from file + NetworkConfig: dbConfig.NetworkConfig, // Keep from DB + SendBackRawResponse: dbConfig.SendBackRawResponse, // Keep from DB + ConfigHash: dbProviderHash, // Provider hash unchanged } // Verify provider config is preserved @@ -3199,9 +3208,9 @@ func TestProviderHashComparison_BothChangedIndependently(t *testing.T) { changedKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-new-456", // CHANGED + Value: "sk-new-456", // CHANGED Models: []string{"gpt-4", "o1"}, // CHANGED - Weight: 2.0, // CHANGED + Weight: 2.0, // CHANGED } changedKeyHash, _ := configstore.GenerateKeyHash(changedKey) @@ -3301,7 +3310,7 @@ func TestProviderHashComparison_NeitherChanged(t *testing.T) { // === Verify: Both hashes match === if dbProviderHash != fileProviderHash { - t.Errorf("Expected provider hash to be SAME, got DB=%s File=%s", + t.Errorf("Expected provider hash to be SAME, got DB=%s File=%s", dbProviderHash[:16], fileProviderHash[:16]) } else { t.Log("✓ Provider hash unchanged") @@ -3351,9 +3360,9 @@ func TestKeyLevelSync_ProviderHashMatch_SingleKeyChanged(t *testing.T) { fileKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-new-value", // CHANGED + Value: "sk-new-value", // CHANGED Models: []string{"gpt-4", "gpt-4-turbo"}, // CHANGED - Weight: 2.0, // CHANGED + Weight: 2.0, // CHANGED } fileKeyHash, _ := configstore.GenerateKeyHash(fileKey) @@ -3464,9 +3473,9 @@ func TestKeyLevelSync_ProviderHashMatch_NewKeyInFile(t *testing.T) { fileKey1 := schemas.Key{ ID: "key-1", Name: "openai-key-1", - Value: "sk-key-1", // SAME + Value: "sk-key-1", // SAME Models: []string{"gpt-4"}, // SAME - Weight: 1.0, // SAME + Weight: 1.0, // SAME } newFileKey := schemas.Key{ ID: "key-2", @@ -3593,9 +3602,9 @@ func TestKeyLevelSync_ProviderHashMatch_KeyOnlyInDB(t *testing.T) { fileKey1 := schemas.Key{ ID: "key-1", Name: "openai-key-1", - Value: "sk-key-1", // SAME + Value: "sk-key-1", // SAME Models: []string{"gpt-4"}, // SAME - Weight: 1.0, // SAME + Weight: 1.0, // SAME } fileConfig := configstore.ProviderConfig{ @@ -3718,16 +3727,16 @@ func TestKeyLevelSync_ProviderHashMatch_MixedScenario(t *testing.T) { fileUnchangedKey := schemas.Key{ ID: "key-unchanged", Name: "unchanged-key", - Value: "sk-unchanged", // SAME + Value: "sk-unchanged", // SAME Models: []string{"gpt-4"}, // SAME - Weight: 1.0, // SAME + Weight: 1.0, // SAME } fileChangedKey := schemas.Key{ ID: "key-changed", Name: "changed-key", - Value: "sk-NEW-value", // CHANGED + Value: "sk-NEW-value", // CHANGED Models: []string{"gpt-4", "gpt-4-turbo"}, // CHANGED - Weight: 2.0, // CHANGED + Weight: 2.0, // CHANGED } newFileKey := schemas.Key{ ID: "key-new", @@ -4861,8 +4870,8 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Endpoint: "https://new-azure.openai.azure.com", // Changed! APIVersion: stringPtr("2024-10-21"), // Changed! Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", // Added! + "gpt-4": "gpt-4-deployment", + "gpt-4o": "gpt-4o-deployment", // Added! }, }, }, @@ -5080,7 +5089,7 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: "AKIAIOSFODNN7EXAMPLE", SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - Region: stringPtr("us-west-2"), // Changed! + Region: stringPtr("us-west-2"), // Changed! ARN: stringPtr("arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile"), // Added! Deployments: map[string]string{ "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", @@ -5091,7 +5100,7 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { }, NetworkConfig: &schemas.NetworkConfig{ BaseURL: "https://bedrock-runtime.us-west-2.amazonaws.com", // Changed! - MaxRetries: 5, // Changed! + MaxRetries: 5, // Changed! }, SendBackRawResponse: true, // Changed! } @@ -5517,9 +5526,9 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin Value: "", Weight: 1, BedrockKeyConfig: &schemas.BedrockKeyConfig{ - AccessKey: "AKIAIOSFODNN7EXAMPLE", // Different! - SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", // Different! - Region: stringPtr("us-east-1"), // Same + AccessKey: "AKIAIOSFODNN7EXAMPLE", // Different! + SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", // Different! + Region: stringPtr("us-east-1"), // Same Deployments: map[string]string{ "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", // Same }, @@ -5528,7 +5537,7 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin }, NetworkConfig: &schemas.NetworkConfig{ BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com", // Same - MaxRetries: 3, // Same + MaxRetries: 3, // Same }, SendBackRawResponse: false, // Same } @@ -5609,7 +5618,7 @@ func TestProviderHashComparison_AzureConfigChangedInFile(t *testing.T) { Weight: 1, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: "https://NEW-azure.openai.azure.com", // Changed! - APIVersion: stringPtr("2024-10-21"), // Changed! + APIVersion: stringPtr("2024-10-21"), // Changed! Deployments: map[string]string{ "gpt-4o": "gpt-4o-deployment", // Added! }, @@ -5700,7 +5709,7 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: "AKIAIOSFODNN7EXAMPLE", SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - Region: stringPtr("us-west-2"), // Changed! + Region: stringPtr("us-west-2"), // Changed! ARN: stringPtr("arn:aws:bedrock:us-west-2:123456789012:inference-profile/new-profile"), // Added! Deployments: map[string]string{ "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", // Added! @@ -5710,7 +5719,7 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { }, NetworkConfig: &schemas.NetworkConfig{ BaseURL: "https://bedrock-runtime.us-west-2.amazonaws.com", // Changed! - MaxRetries: 5, // Changed! + MaxRetries: 5, // Changed! }, SendBackRawResponse: true, // Changed! } @@ -6663,7 +6672,7 @@ func TestSQLite_Provider_NewProviderFromFile(t *testing.T) { // Load config - this should create the provider in the DB ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -6701,7 +6710,7 @@ func TestSQLite_Provider_HashMatch_DBPreserved(t *testing.T) { // First load - creates provider in DB ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6712,7 +6721,7 @@ func TestSQLite_Provider_HashMatch_DBPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - should preserve DB config - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6744,7 +6753,7 @@ func TestSQLite_Provider_HashMismatch_FileSync(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6762,7 +6771,7 @@ func TestSQLite_Provider_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load with modified config.json - should sync from file - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6796,7 +6805,7 @@ func TestSQLite_Provider_DBOnlyProvider_Preserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6828,7 +6837,7 @@ func TestSQLite_Provider_DBOnlyProvider_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (no Anthropic) - should preserve DB-added provider - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6861,7 +6870,7 @@ func TestSQLite_Provider_RoundTrip(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6881,7 +6890,7 @@ func TestSQLite_Provider_RoundTrip(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - should preserve DB changes since hash matches - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6923,7 +6932,7 @@ func TestSQLite_Key_NewKeyFromFile(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -6958,7 +6967,7 @@ func TestSQLite_Key_HashMatch_DBKeyPreserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6969,7 +6978,7 @@ func TestSQLite_Key_HashMatch_DBKeyPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7008,7 +7017,7 @@ func TestSQLite_Key_DashboardAddedKey_Preserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7030,7 +7039,7 @@ func TestSQLite_Key_DashboardAddedKey_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (still has only file-key) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7075,7 +7084,7 @@ func TestSQLite_Key_KeyValueChange_Detected(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7100,7 +7109,7 @@ func TestSQLite_Key_KeyValueChange_Detected(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load with modified config - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7135,7 +7144,7 @@ func TestSQLite_Key_MultipleKeys_MergeLogic(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7162,7 +7171,7 @@ func TestSQLite_Key_MultipleKeys_MergeLogic(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (still has key-1 and key-2) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7205,7 +7214,7 @@ func TestSQLite_VirtualKey_NewFromFile(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -7245,7 +7254,7 @@ func TestSQLite_VirtualKey_HashMatch_DBPreserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7256,7 +7265,7 @@ func TestSQLite_VirtualKey_HashMatch_DBPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7286,7 +7295,7 @@ func TestSQLite_VirtualKey_HashMismatch_FileSync(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7313,7 +7322,7 @@ func TestSQLite_VirtualKey_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load with modified config - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7348,7 +7357,7 @@ func TestSQLite_VirtualKey_DBOnlyVK_Preserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7370,7 +7379,7 @@ func TestSQLite_VirtualKey_DBOnlyVK_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (only has vk-file) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7423,7 +7432,7 @@ func TestSQLite_VirtualKey_WithProviderConfigs(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -7484,7 +7493,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigs(t *testing.T) { // First load - bootstrap path ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7515,7 +7524,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigs(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - merge path (this is where the bug is) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7586,7 +7595,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigKeys(t *testing.T) { // First load - bootstrap path (creates provider with key in DB) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7631,7 +7640,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigKeys(t *testing.T) { // Second load - merge path // BEFORE FIX: This would fail because GORM tries to INSERT the key again // AFTER FIX: CreateVirtualKeyProviderConfig uses Append() to associate existing keys - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7777,7 +7786,7 @@ func TestSQLite_VKProviderConfig_NewConfig(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -7849,7 +7858,7 @@ func TestSQLite_VKProviderConfig_KeyReference(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -8157,7 +8166,7 @@ func TestSQLite_FullLifecycle_InitialLoad(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -8216,7 +8225,7 @@ func TestSQLite_FullLifecycle_SecondLoadNoChanges(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8229,7 +8238,7 @@ func TestSQLite_FullLifecycle_SecondLoadNoChanges(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8271,7 +8280,7 @@ func TestSQLite_FullLifecycle_FileChange_Selective(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8304,7 +8313,7 @@ func TestSQLite_FullLifecycle_FileChange_Selective(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8360,7 +8369,7 @@ func TestSQLite_FullLifecycle_DashboardEdits_ThenFileUnchanged(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8398,7 +8407,7 @@ func TestSQLite_FullLifecycle_DashboardEdits_ThenFileUnchanged(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with SAME config.json (unchanged) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8617,7 +8626,7 @@ func TestSQLite_VirtualKey_WithMCPConfigs(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - creates VK - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8706,7 +8715,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8783,7 +8792,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - should trigger reconciliation - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8885,7 +8894,7 @@ func TestSQLite_VirtualKey_DashboardProviderConfig_PreservedOnFileChange(t *test createConfigFile(t, tempDir, configData) // Step 2: First load - bootstrap path - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8949,7 +8958,7 @@ func TestSQLite_VirtualKey_DashboardProviderConfig_PreservedOnFileChange(t *test createConfigFile(t, tempDir, configData2) // Step 5: Second load - merge path with hash mismatch - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9034,7 +9043,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_PreservedOnFileChange(t *testing.T createConfigFile(t, tempDir, configData) // Step 2: First load - bootstrap path - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9125,7 +9134,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_PreservedOnFileChange(t *testing.T createConfigFile(t, tempDir, configData2) // Step 5: Second load - merge path with hash mismatch - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9205,7 +9214,7 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9244,7 +9253,7 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - should add MCP configs - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9277,7 +9286,7 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { // Third load - mcpClient2 config should be PRESERVED (not deleted) // This protects dashboard-added configs from accidental deletion - config3, err := LoadConfig(ctx, tempDir) + config3, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Third LoadConfig failed: %v", err) } @@ -9328,7 +9337,7 @@ func TestSQLite_VKMCPConfig_UpdateTools(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9378,7 +9387,7 @@ func TestSQLite_VKMCPConfig_UpdateTools(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - should update tools - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9422,7 +9431,7 @@ func TestSQLite_VK_ProviderAndMCPConfigs_Combined(t *testing.T) { createConfigFile(t, tempDir, configData) // First load to set up DB - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9460,7 +9469,7 @@ func TestSQLite_VK_ProviderAndMCPConfigs_Combined(t *testing.T) { createConfigFile(t, tempDir, configData2) // Load config - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10253,7 +10262,7 @@ func TestSQLite_Budget_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10296,7 +10305,7 @@ func TestSQLite_Budget_HashMatch_DBPreserved(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10307,7 +10316,7 @@ func TestSQLite_Budget_HashMatch_DBPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load - same config - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10335,7 +10344,7 @@ func TestSQLite_Budget_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10346,7 +10355,7 @@ func TestSQLite_Budget_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) // Second load - should sync from file - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10374,7 +10383,7 @@ func TestSQLite_Budget_DBOnly_Preserved(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10391,7 +10400,7 @@ func TestSQLite_Budget_DBOnly_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Reload - dashboard budget should be preserved - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10525,7 +10534,7 @@ func TestSQLite_RateLimit_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10563,7 +10572,7 @@ func TestSQLite_RateLimit_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10575,7 +10584,7 @@ func TestSQLite_RateLimit_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) // Second load - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10668,7 +10677,7 @@ func TestSQLite_Customer_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10702,7 +10711,7 @@ func TestSQLite_Customer_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10712,7 +10721,7 @@ func TestSQLite_Customer_HashMismatch_FileSync(t *testing.T) { configData.Governance.Customers[0].Name = "Updated Customer" createConfigFile(t, tempDir, configData) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10843,7 +10852,7 @@ func TestSQLite_Team_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir) + config, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10878,7 +10887,7 @@ func TestSQLite_Team_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10888,7 +10897,7 @@ func TestSQLite_Team_HashMismatch_FileSync(t *testing.T) { configData.Governance.Teams[0].Name = "Updated Team" createConfigFile(t, tempDir, configData) - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -11267,7 +11276,7 @@ func TestSQLite_Governance_FullReconciliation(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -11297,7 +11306,7 @@ func TestSQLite_Governance_FullReconciliation(t *testing.T) { createConfigFile(t, tempDir, configData) // Reload and verify all entities are updated - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -11332,7 +11341,7 @@ func TestSQLite_Governance_DBOnly_AllPreserved(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir) + config1, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -11368,7 +11377,7 @@ func TestSQLite_Governance_DBOnly_AllPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Reload - all dashboard entities should be preserved - config2, err := LoadConfig(ctx, tempDir) + config2, err := LoadConfig(ctx, tempDir, nil) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -11471,7 +11480,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for StdioConfig\nBefore save: %s\nAfter load: %s\nStdioConfig populated: %v", + t.Errorf("Hash mismatch after GORM round-trip for StdioConfig\nBefore save: %s\nAfter load: %s\nStdioConfig populated: %v", hashBeforeSave, hashAfterLoad, mcpFromDB.StdioConfig != nil) } }) @@ -11502,7 +11511,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for ToolsToExecute\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for ToolsToExecute\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11536,7 +11545,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for Headers\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for Headers\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11568,7 +11577,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateMCPClientHash(mcpFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for all fields\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for all fields\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11607,7 +11616,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch when using Find() (migration pattern)\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch when using Find() (migration pattern)\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11653,7 +11662,7 @@ func TestGeneratePluginHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for plugin Config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for plugin Config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11683,7 +11692,7 @@ func TestGeneratePluginHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GeneratePluginHash(pluginFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for nested config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for nested config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11706,7 +11715,7 @@ func TestGeneratePluginHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GeneratePluginHash(pluginFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for empty config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for empty config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11747,7 +11756,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Profile\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Profile\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11774,7 +11783,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateTeamHash(teamFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11801,7 +11810,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateTeamHash(teamFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Claims\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Claims\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11830,7 +11839,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateTeamHash(teamFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for all fields\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for all fields\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11882,7 +11891,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for NetworkConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for NetworkConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11922,7 +11931,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for ConcurrencyAndBufferSize\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for ConcurrencyAndBufferSize\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11958,7 +11967,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := providerConfigFromDB.GenerateConfigHash("openai") if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for ProxyConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for ProxyConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11994,7 +12003,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := providerConfigFromDB.GenerateConfigHash("custom") if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for CustomProviderConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for CustomProviderConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12055,7 +12064,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Models\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Models\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12105,7 +12114,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for AzureKeyConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for AzureKeyConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12206,7 +12215,7 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for PrometheusLabels\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for PrometheusLabels\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12254,7 +12263,7 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for AllowedOrigins\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for AllowedOrigins\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) diff --git a/transports/bifrost-http/lib/lib.go b/transports/bifrost-http/lib/lib.go index 4669aca21..6562d1c72 100644 --- a/transports/bifrost-http/lib/lib.go +++ b/transports/bifrost-http/lib/lib.go @@ -1,7 +1,11 @@ package lib import ( + "context" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/framework/modelcatalog" ) var logger schemas.Logger @@ -10,3 +14,9 @@ var logger schemas.Logger func SetLogger(l schemas.Logger) { logger = l } + +type EnterpriseOverrides interface { + GetGovernancePluginName() string + LoadGovernancePlugin(ctx context.Context, config *Config) (schemas.Plugin, error) + LoadPricingManager(ctx context.Context, pricingConfig *modelcatalog.Config, configStore configstore.ConfigStore) (*modelcatalog.ModelCatalog, error) +} diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 46e6d6f29..802f7401b 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -18,11 +18,13 @@ import ( "github.com/bytedance/sonic" "github.com/fasthttp/router" + "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/modelcatalog" dynamicPlugins "github.com/maximhq/bifrost/framework/plugins" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" @@ -64,22 +66,19 @@ type ServerCallbacks interface { ForceReloadPricing(ctx context.Context) error ReloadProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error UpdateDropExcessRequests(ctx context.Context, value bool) + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error ReloadTeam(ctx context.Context, id string) (*tables.TableTeam, error) RemoveTeam(ctx context.Context, id string) error ReloadCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) RemoveCustomer(ctx context.Context, id string) error ReloadVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error + GetGovernanceData() *governance.GovernanceData AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error EditMCPClient(ctx context.Context, id string, updatedConfig schemas.MCPClientConfig) error } -var ( - BifrostContextKeyBudgetIDs schemas.BifrostContextKey = "budget_ids" - BifrostContextKeyBudgetID schemas.BifrostContextKey = "budget_id" -) - // BifrostHTTPServer represents a HTTP server instance. type BifrostHTTPServer struct { ctx context.Context @@ -103,10 +102,12 @@ type BifrostHTTPServer struct { Client *bifrost.Bifrost Config *lib.Config - Server *fasthttp.Server - Router *router.Router - WebSocketHandler *handlers.WebSocketHandler - LogsCleaner *logstore.LogsCleaner + OSSToEnterprisePluginNameOverrides map[string]string + Server *fasthttp.Server + Router *router.Router + WebSocketHandler *handlers.WebSocketHandler + LogsCleaner *logstore.LogsCleaner + MCPServerHandler *handlers.MCPServerHandler } var logger schemas.Logger @@ -198,18 +199,18 @@ func MarshalPluginConfig[T any](source any) (*T, error) { } type GovernanceInMemoryStore struct { - config *lib.Config + Config *lib.Config } func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig { // Use read lock for thread-safe access - no need to copy on hot path - s.config.Mu.RLock() - defer s.config.Mu.RUnlock() - return s.config.Providers + s.Config.Mu.RLock() + defer s.Config.Mu.RUnlock() + return s.Config.Providers } // LoadPlugin loads a plugin by name and returns it as type T. -func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string, pluginConfig any, bifrostConfig *lib.Config) (T, error) { +func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string, pluginConfig any, bifrostConfig *lib.Config, EnterpriseOverrides lib.EnterpriseOverrides) (T, error) { var zero T if path != nil { logger.Info("loading dynamic plugin %s from path %s", name, *path) @@ -260,13 +261,13 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string return p, nil } return zero, fmt.Errorf("logging plugin type mismatch") - case governance.PluginName: + case EnterpriseOverrides.GetGovernancePluginName(): governanceConfig, err := MarshalPluginConfig[governance.Config](pluginConfig) if err != nil { return zero, fmt.Errorf("failed to marshal governance plugin config: %v", err) } inMemoryStore := &GovernanceInMemoryStore{ - config: bifrostConfig, + Config: bifrostConfig, } plugin, err := governance.Init(ctx, governanceConfig, logger, bifrostConfig.ConfigStore, bifrostConfig.GovernanceConfig, bifrostConfig.PricingManager, inMemoryStore) if err != nil { @@ -321,12 +322,12 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string } // LoadPlugins loads the plugins for the server. -func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []schemas.PluginStatus, error) { +func LoadPlugins(ctx context.Context, config *lib.Config, EnterpriseOverrides lib.EnterpriseOverrides) ([]schemas.Plugin, []schemas.PluginStatus, error) { var err error pluginStatus := []schemas.PluginStatus{} plugins := []schemas.Plugin{} // Initialize telemetry plugin - promPlugin, err := LoadPlugin[*telemetry.PrometheusPlugin](ctx, telemetry.PluginName, nil, nil, config) + promPlugin, err := LoadPlugin[*telemetry.PrometheusPlugin](ctx, telemetry.PluginName, nil, nil, config, EnterpriseOverrides) if err != nil { logger.Error("failed to initialize telemetry plugin: %v", err) pluginStatus = append(pluginStatus, schemas.PluginStatus{ @@ -348,7 +349,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []s // Use dedicated logs database with high-scale optimizations loggingPlugin, err = LoadPlugin[*logging.LoggerPlugin](ctx, logging.PluginName, nil, &logging.Config{ DisableContentLogging: &config.ClientConfig.DisableContentLogging, - }, config) + }, config, EnterpriseOverrides) if err != nil { logger.Error("failed to initialize logging plugin: %v", err) pluginStatus = append(pluginStatus, schemas.PluginStatus{ @@ -372,37 +373,34 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []s }) } // Initializing governance plugin - var governancePlugin *governance.GovernancePlugin if config.ClientConfig.EnableGovernance { // Initialize governance plugin - governancePlugin, err = LoadPlugin[*governance.GovernancePlugin](ctx, governance.PluginName, nil, &governance.Config{ - IsVkMandatory: &config.ClientConfig.EnforceGovernanceHeader, - }, config) + governancePlugin, err := EnterpriseOverrides.LoadGovernancePlugin(ctx, config) if err != nil { logger.Error("failed to initialize governance plugin: %s", err.Error()) pluginStatus = append(pluginStatus, schemas.PluginStatus{ - Name: governance.PluginName, + Name: EnterpriseOverrides.GetGovernancePluginName(), Status: schemas.PluginStatusError, Logs: []string{fmt.Sprintf("error initializing governance plugin %v", err)}, }) - } else { + } else if governancePlugin != nil { plugins = append(plugins, governancePlugin) pluginStatus = append(pluginStatus, schemas.PluginStatus{ - Name: governance.PluginName, + Name: EnterpriseOverrides.GetGovernancePluginName(), Status: schemas.PluginStatusActive, Logs: []string{"governance plugin initialized successfully"}, }) } } else { pluginStatus = append(pluginStatus, schemas.PluginStatus{ - Name: governance.PluginName, + Name: EnterpriseOverrides.GetGovernancePluginName(), Status: schemas.PluginStatusDisabled, Logs: []string{"governance plugin disabled"}, }) } for _, plugin := range config.PluginConfigs { // Skip built-in plugins that are already handled above - if plugin.Name == telemetry.PluginName || plugin.Name == logging.PluginName || plugin.Name == governance.PluginName { + if plugin.Name == telemetry.PluginName || plugin.Name == logging.PluginName || plugin.Name == EnterpriseOverrides.GetGovernancePluginName() { continue } if !plugin.Enabled { @@ -413,7 +411,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []s }) continue } - pluginInstance, err := LoadPlugin[schemas.Plugin](ctx, plugin.Name, plugin.Path, plugin.Config, config) + pluginInstance, err := LoadPlugin[schemas.Plugin](ctx, plugin.Name, plugin.Path, plugin.Config, config, EnterpriseOverrides) if err != nil { if slices.Contains(enterprisePlugins, plugin.Name) { continue @@ -458,60 +456,110 @@ func FindPluginByName[T schemas.Plugin](plugins []schemas.Plugin, name string) ( // AddMCPClient adds a new MCP client to the in-memory store func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error { - return s.Config.AddMCPClient(ctx, clientConfig) + if err := s.Config.AddMCPClient(ctx, clientConfig); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after adding client: %v", err) + } + return nil +} + +// EditMCPClient edits an MCP client in the in-memory store +func (s *BifrostHTTPServer) EditMCPClient(ctx context.Context, id string, updatedConfig schemas.MCPClientConfig) error { + if err := s.Config.EditMCPClient(ctx, id, updatedConfig); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after editing client: %v", err) + } + return nil } // RemoveMCPClient removes an MCP client from the in-memory store func (s *BifrostHTTPServer) RemoveMCPClient(ctx context.Context, id string) error { - return s.Config.RemoveMCPClient(ctx, id) + if err := s.Config.RemoveMCPClient(ctx, id); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after removing client: %v", err) + } + return nil } -// EditMCPClient edits an MCP client in the in-memory store -func (s *BifrostHTTPServer) EditMCPClient(ctx context.Context, id string, updatedConfig schemas.MCPClientConfig) error { - return s.Config.EditMCPClient(ctx, id, updatedConfig) +// ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. +func (s *BifrostHTTPServer) ExecuteChatMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { + return s.Client.ExecuteChatMCPTool(ctx, toolCall) +} + +// ExecuteResponsesMCPTool executes an MCP tool call and returns the result as a responses message. +func (s *BifrostHTTPServer) ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) { + return s.Client.ExecuteResponsesMCPTool(ctx, toolCall) +} + +func (s *BifrostHTTPServer) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { + return s.Client.GetAvailableMCPTools(ctx) +} + +// getGovernancePlugin safely retrieves the governance plugin with proper locking. +// It acquires a read lock, finds the plugin, releases the lock, performs type assertion, +// and returns the BaseGovernancePlugin implementation or an error. +func (s *BifrostHTTPServer) getGovernancePlugin() (governance.BaseGovernancePlugin, error) { + s.PluginsMutex.RLock() + plugin, err := FindPluginByName[schemas.Plugin](s.Plugins, s.GetGovernancePluginName()) + s.PluginsMutex.RUnlock() + if err != nil { + return nil, err + } + if plugin == nil { + return nil, fmt.Errorf("governance plugin not found") + } + governancePlugin, ok := plugin.(governance.BaseGovernancePlugin) + if !ok { + return nil, fmt.Errorf("governance plugin does not implement BaseGovernancePlugin") + } + return governancePlugin, nil } // ReloadVirtualKey reloads a virtual key from the in-memory store func (s *BifrostHTTPServer) ReloadVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) { // Load relationships for response - preloadedVk, err := s.Config.ConfigStore.GetVirtualKey(ctx, id) + preloadedVk, err := s.Config.ConfigStore.RetryOnNotFound(ctx, func(ctx context.Context) (any, error) { + preloadedVk, err := s.Config.ConfigStore.GetVirtualKey(ctx, id) + if err != nil { + return nil, err + } + return preloadedVk, nil + }, lib.DBLookupMaxRetries, lib.DBLookupDelay) if err != nil { - logger.Error("failed to load relationships for created VK: %v", err) + logger.Error("failed to load virtual key: %v", err) return nil, err } - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) - if err != nil { - return nil, err + if preloadedVk == nil { + logger.Error("virtual key not found") + return nil, fmt.Errorf("virtual key not found") } - if governancePlugin == nil { - return nil, fmt.Errorf("governance plugin not found") + // Type assertion (should never happen) + virtualKey, ok := preloadedVk.(*tables.TableVirtualKey) + if !ok { + logger.Error("virtual key type assertion failed") + return nil, fmt.Errorf("virtual key type assertion failed") } - // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateVirtualKeyInMemory(preloadedVk) - // If budget was created, add it to in-memory store - if preloadedVk.BudgetID != nil && preloadedVk.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(preloadedVk.Budget) - } - // Add provider-level budgets to in-memory store - if preloadedVk.ProviderConfigs != nil { - for _, pc := range preloadedVk.ProviderConfigs { - if pc.BudgetID != nil && pc.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(pc.Budget) - } - } + governancePlugin, err := s.getGovernancePlugin() + if err != nil { + return nil, err } - return preloadedVk, nil + governancePlugin.GetGovernanceStore().UpdateVirtualKeyInMemory(virtualKey, nil, nil, nil) + s.MCPServerHandler.SyncVKMCPServer(virtualKey) + return virtualKey, nil } // RemoveVirtualKey removes a virtual key from the in-memory store func (s *BifrostHTTPServer) RemoveVirtualKey(ctx context.Context, id string) error { - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return err } - if governancePlugin == nil { - return fmt.Errorf("governance plugin not found") - } preloadedVk, err := s.Config.ConfigStore.GetVirtualKey(ctx, id) if err != nil { if !errors.Is(err, configstore.ErrNotFound) { @@ -521,26 +569,10 @@ func (s *BifrostHTTPServer) RemoveVirtualKey(ctx context.Context, id string) err if preloadedVk == nil { // This could be broadcast message from other server, so we will just clean up in-memory store governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(id) - if budgetIDs, ok := ctx.Value(BifrostContextKeyBudgetIDs).([]string); ok { - for _, budgetID := range budgetIDs { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(budgetID) - } - } return nil } governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(id) - // If budget was created, delete it from in-memory store - if preloadedVk.BudgetID != nil && preloadedVk.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*preloadedVk.BudgetID) - } - // Delete provider-level budgets from in-memory store - if preloadedVk.ProviderConfigs != nil { - for _, pc := range preloadedVk.ProviderConfigs { - if pc.BudgetID != nil && pc.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*pc.BudgetID) - } - } - } + s.MCPServerHandler.DeleteVKMCPServer(preloadedVk.Value) return nil } @@ -552,31 +584,21 @@ func (s *BifrostHTTPServer) ReloadTeam(ctx context.Context, id string) (*tables. logger.Error("failed to load relationships for created team: %v", err) return nil, err } - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return nil, err } - if governancePlugin == nil { - return nil, fmt.Errorf("governance plugin not found") - } // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateTeamInMemory(preloadedTeam) - // If budget was created, add it to in-memory store - if preloadedTeam.BudgetID != nil && preloadedTeam.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(preloadedTeam.Budget) - } + governancePlugin.GetGovernanceStore().UpdateTeamInMemory(preloadedTeam, nil) return preloadedTeam, nil } // RemoveTeam removes a team from the in-memory store func (s *BifrostHTTPServer) RemoveTeam(ctx context.Context, id string) error { - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return err } - if governancePlugin == nil { - return fmt.Errorf("governance plugin not found") - } preloadedTeam, err := s.Config.ConfigStore.GetTeam(ctx, id) if err != nil { if !errors.Is(err, configstore.ErrNotFound) { @@ -586,16 +608,9 @@ func (s *BifrostHTTPServer) RemoveTeam(ctx context.Context, id string) error { if preloadedTeam == nil { // At-least deleting from in-memory store to avoid conflicts governancePlugin.GetGovernanceStore().DeleteTeamInMemory(id) - if budgetID, ok := ctx.Value(BifrostContextKeyBudgetID).(string); ok { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(budgetID) - } return nil } governancePlugin.GetGovernanceStore().DeleteTeamInMemory(id) - // If budget was created, delete it from in-memory store - if preloadedTeam.BudgetID != nil && preloadedTeam.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*preloadedTeam.BudgetID) - } return nil } @@ -605,31 +620,21 @@ func (s *BifrostHTTPServer) ReloadCustomer(ctx context.Context, id string) (*tab if err != nil { return nil, err } - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return nil, err } - if governancePlugin == nil { - return nil, fmt.Errorf("governance plugin not found") - } // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateCustomerInMemory(preloadedCustomer) - // If budget was created, add it to in-memory store - if preloadedCustomer.BudgetID != nil && preloadedCustomer.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(preloadedCustomer.Budget) - } + governancePlugin.GetGovernanceStore().UpdateCustomerInMemory(preloadedCustomer, nil) return preloadedCustomer, nil } // RemoveCustomer removes a customer from the in-memory store func (s *BifrostHTTPServer) RemoveCustomer(ctx context.Context, id string) error { - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return err } - if governancePlugin == nil { - return fmt.Errorf("governance plugin not found") - } preloadedCustomer, err := s.Config.ConfigStore.GetCustomer(ctx, id) if err != nil { if !errors.Is(err, configstore.ErrNotFound) { @@ -639,15 +644,23 @@ func (s *BifrostHTTPServer) RemoveCustomer(ctx context.Context, id string) error if preloadedCustomer == nil { // At-least deleting from in-memory store to avoid conflicts governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(id) - if budgetID, ok := ctx.Value(BifrostContextKeyBudgetID).(string); ok { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(budgetID) - } return nil } governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(id) - // If budget was created, delete it from in-memory store - if preloadedCustomer.BudgetID != nil && preloadedCustomer.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*preloadedCustomer.BudgetID) + return nil +} + +// GetGovernanceData returns the governance data +func (s *BifrostHTTPServer) GetGovernanceData() *governance.GovernanceData { + s.PluginsMutex.RLock() + governancePlugin, err := FindPluginByName[schemas.Plugin](s.Plugins, s.GetGovernancePluginName()) + s.PluginsMutex.RUnlock() + if err != nil { + return nil + } + // Check if GetGovernanceStore method is implemented + if governancePlugin, ok := governancePlugin.(governance.BaseGovernancePlugin); ok { + return governancePlugin.GetGovernanceStore().GetGovernanceData() } return nil } @@ -699,6 +712,14 @@ func (s *BifrostHTTPServer) UpdateDropExcessRequests(ctx context.Context, value s.Client.UpdateDropExcessRequests(value) } +// UpdateMCPToolManagerConfig updates the MCP tool manager config +func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { + if s.Config == nil { + return fmt.Errorf("config not found") + } + return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds, codeModeBindingLevel) +} + // UpdatePluginStatus updates the status of a plugin func (s *BifrostHTTPServer) UpdatePluginStatus(name string, status string, logs []string) error { s.pluginStatusMutex.Lock() @@ -799,7 +820,7 @@ func (s *BifrostHTTPServer) SyncLoadedPlugin(ctx context.Context, plugin schemas // Uses atomic CompareAndSwap with retry loop to handle concurrent updates safely. func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error { logger.Debug("reloading plugin %s", name) - newPlugin, err := LoadPlugin[schemas.Plugin](ctx, name, path, pluginConfig, s.Config) + newPlugin, err := LoadPlugin[schemas.Plugin](ctx, name, path, pluginConfig, s.Config, s) if err != nil { s.UpdatePluginStatus(name, schemas.PluginStatusError, []string{fmt.Sprintf("error loading plugin %s: %v", name, err)}) return err @@ -925,7 +946,7 @@ func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlew } // RegisterAPIRoutes initializes the routes for the Bifrost HTTP server. -func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks ServerCallbacks, middlewares ...lib.BifrostHTTPMiddleware) error { +func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks ServerCallbacks, EnterpriseOverrides lib.EnterpriseOverrides, middlewares ...lib.BifrostHTTPMiddleware) error { var err error // Initializing plugin specific handlers var loggingHandler *handlers.LoggingHandler @@ -934,7 +955,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser loggingHandler = handlers.NewLoggingHandler(loggerPlugin.GetPluginLogManager(), s) } var governanceHandler *handlers.GovernanceHandler - governancePlugin, _ := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, _ := FindPluginByName[schemas.Plugin](s.Plugins, EnterpriseOverrides.GetGovernancePluginName()) if governancePlugin != nil { governanceHandler, err = handlers.NewGovernanceHandler(callbacks, s.Config.ConfigStore) if err != nil { @@ -962,6 +983,11 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser healthHandler := handlers.NewHealthHandler(s.Config) providerHandler := handlers.NewProviderHandler(callbacks, s.Config, s.Client) mcpHandler := handlers.NewMCPHandler(callbacks, s.Client, s.Config) + mcpServerHandler, err := handlers.NewMCPServerHandler(ctx, s.Config, s) + if err != nil { + return fmt.Errorf("failed to initialize mcp server handler: %v", err) + } + s.MCPServerHandler = mcpServerHandler configHandler := handlers.NewConfigHandler(callbacks, s.Config) pluginsHandler := handlers.NewPluginsHandler(callbacks, s.Config.ConfigStore) sessionHandler := handlers.NewSessionHandler(s.Config.ConfigStore) @@ -969,6 +995,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser healthHandler.RegisterRoutes(s.Router, middlewares...) providerHandler.RegisterRoutes(s.Router, middlewares...) mcpHandler.RegisterRoutes(s.Router, middlewares...) + mcpServerHandler.RegisterRoutes(s.Router, middlewares...) configHandler.RegisterRoutes(s.Router, middlewares...) if pluginsHandler != nil { pluginsHandler.RegisterRoutes(s.Router, middlewares...) @@ -1036,6 +1063,34 @@ func (s *BifrostHTTPServer) GetAllRedactedVirtualKeys(ctx context.Context, ids [ return virtualKeys } +func (s *BifrostHTTPServer) GetGovernancePluginName() string { + if s.OSSToEnterprisePluginNameOverrides != nil { + if name, ok := s.OSSToEnterprisePluginNameOverrides[governance.PluginName]; ok && name != "" { + return name + } + } + return governance.PluginName +} + +func (s *BifrostHTTPServer) LoadGovernancePlugin(ctx context.Context, config *lib.Config) (schemas.Plugin, error) { + governancePlugin, err := LoadPlugin[*governance.GovernancePlugin](ctx, governance.PluginName, nil, &governance.Config{ + IsVkMandatory: &config.ClientConfig.EnforceGovernanceHeader, + }, config, s) + + if err != nil { + return nil, fmt.Errorf("failed to initialize governance plugin: %v", err) + } + return governancePlugin, nil +} + +func (s *BifrostHTTPServer) LoadPricingManager(ctx context.Context, pricingConfig *modelcatalog.Config, configStore configstore.ConfigStore) (*modelcatalog.ModelCatalog, error) { + pricingManager, err := modelcatalog.Init(ctx, pricingConfig, configStore, nil, logger) + if err != nil { + return nil, fmt.Errorf("failed to initialize pricing manager: %v", err) + } + return pricingManager, nil +} + // PrepareCommonMiddlewares gets the common middlewares for the Bifrost HTTP server func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []lib.BifrostHTTPMiddleware { commonMiddlewares := []lib.BifrostHTTPMiddleware{} @@ -1073,7 +1128,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { return fmt.Errorf("failed to create app directory %s: %v", configDir, err) } // Initialize high-performance configuration store with dedicated database - s.Config, err = lib.LoadConfig(ctx, configDir) + s.Config, err = lib.LoadConfig(ctx, configDir, s) if err != nil { return fmt.Errorf("failed to load config %v", err) } @@ -1112,10 +1167,16 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Load plugins s.pluginStatusMutex.Lock() defer s.pluginStatusMutex.Unlock() - s.Plugins, s.pluginStatus, err = LoadPlugins(ctx, s.Config) + s.Plugins, s.pluginStatus, err = LoadPlugins(ctx, s.Config, s) if err != nil { return fmt.Errorf("failed to load plugins %v", err) } + mcpConfig := s.Config.MCPConfig + if mcpConfig != nil { + mcpConfig.FetchNewRequestIDFunc = func(ctx context.Context) string { + return uuid.New().String() + } + } // Initialize bifrost client // Create account backed by the high-performance store (all processing is done in LoadFromDatabase) // The account interface now benefits from ultra-fast config access times via in-memory storage @@ -1125,7 +1186,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { InitialPoolSize: s.Config.ClientConfig.InitialPoolSize, DropExcessRequests: s.Config.ClientConfig.DropExcessRequests, Plugins: s.Plugins, - MCPConfig: s.Config.MCPConfig, + MCPConfig: mcpConfig, Logger: logger, }) if err != nil { @@ -1166,7 +1227,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { apiMiddlewares = append(apiMiddlewares, handlers.AuthMiddleware(s.Config.ConfigStore)) } // Register routes - err = s.RegisterAPIRoutes(s.ctx, s, apiMiddlewares...) + err = s.RegisterAPIRoutes(s.ctx, s, s, apiMiddlewares...) if err != nil { return fmt.Errorf("failed to initialize routes: %v", err) } @@ -1175,7 +1236,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { inferenceMiddlewares = append(inferenceMiddlewares, handlers.AuthMiddleware(s.Config.ConfigStore)) } // Registering inference middlewares - inferenceMiddlewares = append([]lib.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) + inferenceMiddlewares = append([]lib.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config, s)}, inferenceMiddlewares...) err = s.RegisterInferenceRoutes(s.ctx, inferenceMiddlewares...) if err != nil { return fmt.Errorf("failed to initialize inference routes: %v", err) diff --git a/transports/changelog.md b/transports/changelog.md index e69de29bb..f15357fd8 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -0,0 +1,4 @@ +- feat: added code mode to mcp +- feat: added health monitoring to mcp +- feat: added responses format tool execution support to mcp +- refactor: governance plugin refactored for extensibility and optimization diff --git a/transports/config.schema.json b/transports/config.schema.json index fa7a1a9e2..2097822a9 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -395,6 +395,9 @@ "$ref": "#/$defs/mcp_client_config" }, "description": "MCP client configurations" + }, + "tool_manager_config": { + "$ref": "#/$defs/mcp_tool_manager_config" } }, "additionalProperties": false @@ -1760,6 +1763,23 @@ } ] }, + "mcp_tool_manager_config": { + "type": "object", + "properties": { + "tool_execution_timeout": { + "type": "integer", + "description": "Tool execution timeout in seconds", + "minimum": 1, + "default": 30 + }, + "max_agent_depth": { + "type": "integer", + "description": "Max agent depth", + "minimum": 1, + "default": 10 + } + } + }, "weaviate_config": { "type": "object", "description": "Weaviate configuration for vector store", diff --git a/transports/go.mod b/transports/go.mod index 6b81a6b02..a72a52efd 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -8,14 +8,15 @@ require ( github.com/fasthttp/router v1.5.4 github.com/fasthttp/websocket v1.5.12 github.com/google/uuid v1.6.0 + github.com/mark3labs/mcp-go v0.43.2 github.com/maximhq/bifrost/core v1.2.42 - github.com/maximhq/bifrost/framework v1.1.52 - github.com/maximhq/bifrost/plugins/governance v1.3.53 - github.com/maximhq/bifrost/plugins/logging v1.3.53 - github.com/maximhq/bifrost/plugins/maxim v1.4.53 - github.com/maximhq/bifrost/plugins/otel v1.0.52 - github.com/maximhq/bifrost/plugins/semanticcache v1.3.52 - github.com/maximhq/bifrost/plugins/telemetry v1.3.52 + github.com/maximhq/bifrost/framework v1.1.50 + github.com/maximhq/bifrost/plugins/governance v1.3.51 + github.com/maximhq/bifrost/plugins/logging v1.3.51 + github.com/maximhq/bifrost/plugins/maxim v1.4.51 + github.com/maximhq/bifrost/plugins/otel v1.0.50 + github.com/maximhq/bifrost/plugins/semanticcache v1.3.50 + github.com/maximhq/bifrost/plugins/telemetry v1.3.50 github.com/prometheus/client_golang v1.23.0 github.com/stretchr/testify v1.11.1 github.com/valyala/fasthttp v1.68.0 @@ -89,7 +90,6 @@ require ( github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/mailru/easyjson v0.9.1 // indirect - github.com/mark3labs/mcp-go v0.43.2 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect diff --git a/transports/go.sum b/transports/go.sum index 8b16d1ec7..c9f0de5aa 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -180,22 +180,22 @@ github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuE github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/maximhq/bifrost/core v1.2.42 h1:0G5TD4sZWlT8CwteobFpXmnALGzFQ6lsrDAQl8tr7/k= github.com/maximhq/bifrost/core v1.2.42/go.mod h1:1msCedjIgC8d9TNJyB1z7s+348vh2Bd0u66qAPgpZoA= -github.com/maximhq/bifrost/framework v1.1.52 h1:n36FUjcnXoNQaVVYdkMcBMf6VnthloWaq1rFomdqVVA= -github.com/maximhq/bifrost/framework v1.1.52/go.mod h1:Tb8mLcg0o03s+4m8lizSwN2h9K8AKdRD8e3lrjBk6zo= -github.com/maximhq/bifrost/plugins/governance v1.3.53 h1:SOHu3BK1+v62hT+YMf7uALXvZHEBrPGDiXmGf4hhyrU= -github.com/maximhq/bifrost/plugins/governance v1.3.53/go.mod h1:DDxkv0B+NhrB0G2OPGl6BfJ/IsadlRuZJFIyfhLLei0= -github.com/maximhq/bifrost/plugins/logging v1.3.53 h1:rpqkDutVtVCX/CrBoqeESnE2Ls9Id2p0zrRdAzdsWyc= -github.com/maximhq/bifrost/plugins/logging v1.3.53/go.mod h1:I7MLX6kklqpsu+Udo39AgTa73wpc3mN5xjokdsK8rIw= -github.com/maximhq/bifrost/plugins/maxim v1.4.53 h1:WJxKd3ggIUBB4QlmtYyDTeNUgttEABbkURCK9zHi/nE= -github.com/maximhq/bifrost/plugins/maxim v1.4.53/go.mod h1:tjleeSB6yUeQvvjLR9DlAg15e6OdbnnzNC0bXgw+MN8= +github.com/maximhq/bifrost/framework v1.1.50 h1:dy4awTo1b51hXll8Uk2HdZTZESXGJJCYcsBUOmzTLqU= +github.com/maximhq/bifrost/framework v1.1.50/go.mod h1:opfHztmWhGkg0QSQGvZsU1wXEpMYLpBl01J6LsKECng= +github.com/maximhq/bifrost/plugins/governance v1.3.51 h1:ng1F+N9nKD95qvrpld+aQVjn0ZXzT2RvDOOKfKUmVKk= +github.com/maximhq/bifrost/plugins/governance v1.3.51/go.mod h1:1JFy0SOz3wrann3gwj3AdUwEPAu5ADGeUPakH/fSHSM= +github.com/maximhq/bifrost/plugins/logging v1.3.51 h1:WC7E+xB54aBp1yHiw1ZhomqPn2AMEp/AQkVHwKxtOc8= +github.com/maximhq/bifrost/plugins/logging v1.3.51/go.mod h1:CaPFNbzK+1HEVERtxEY+BnOYRbkov3Gv+YDtrSL4HWQ= +github.com/maximhq/bifrost/plugins/maxim v1.4.51 h1:B3E+TvV4dFjVToH2HQrsMw3tMzCX2+DFvpMjfeeRm8M= +github.com/maximhq/bifrost/plugins/maxim v1.4.51/go.mod h1:JtPxw49OzA8XqEmSlfuRc86cdE/+QlebrQ3ps7izozU= github.com/maximhq/bifrost/plugins/mocker v1.3.52 h1:DgujQBv7IilNEC3K70tu/7unaE89CzNpadwaLoFiCRE= github.com/maximhq/bifrost/plugins/mocker v1.3.52/go.mod h1:wQlAj8+dbiPB4cQHdsQSL4tuVEVN4Iwz5mw0oPY7E/M= -github.com/maximhq/bifrost/plugins/otel v1.0.52 h1:AyKfBzeqJCvhbK3ZqeRn1GOXMm40WEZSar087XUmwbw= -github.com/maximhq/bifrost/plugins/otel v1.0.52/go.mod h1:ZZS+BeX/R4MXpfzt+f49Bx+Z552Kh+i/30WcQcsN1d0= -github.com/maximhq/bifrost/plugins/semanticcache v1.3.52 h1:GejgTEO2MoZ0WGZ9E9VQbTRWjnwO7IYoXl+aCJwFuj8= -github.com/maximhq/bifrost/plugins/semanticcache v1.3.52/go.mod h1:sZAmbwXmdF3D7DHYfQgBUFwz5N21QFo49ncZY8JzqiA= -github.com/maximhq/bifrost/plugins/telemetry v1.3.52 h1:ENwJ+sius5mhI5CZ2jfuerld/xMxJlGLwMWcWFXw3Vk= -github.com/maximhq/bifrost/plugins/telemetry v1.3.52/go.mod h1:Z0cVc2oHApRaX7zQFvLrrlW06XL263bc96V5rFfQjV8= +github.com/maximhq/bifrost/plugins/otel v1.0.50 h1:kd2mD4sNbQVZAd2lCaYHRo7WMlH5Heswuw+jUyM+dos= +github.com/maximhq/bifrost/plugins/otel v1.0.50/go.mod h1:TR8nrgglI3FZXMM+aedOGGNP1dwS/BaWce6rr0S9OKo= +github.com/maximhq/bifrost/plugins/semanticcache v1.3.50 h1:szdnvWtFXZ2e9M9ODts5KRCc5DsOhzIkp81o2JBJtig= +github.com/maximhq/bifrost/plugins/semanticcache v1.3.50/go.mod h1:Pl7wG7UdAg+uiNu0BacKckp6IRlWO3C4ykP0wrCAslE= +github.com/maximhq/bifrost/plugins/telemetry v1.3.50 h1:wFOFfTkeeJHliYssS93DKvX74c+4qJANeiQN4XVw2/o= +github.com/maximhq/bifrost/plugins/telemetry v1.3.50/go.mod h1:PAV6muMHZ+KzMVvww0/ip/A157ffQgnhem+pv6sXGLQ= github.com/maximhq/maxim-go v0.1.14 h1:NQgpf3aRoD2Kq1GAqeSrLn3rQresn1H6mPP3JJ85qhA= github.com/maximhq/maxim-go v0.1.14/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= diff --git a/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx b/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx index 36ef7bf5a..fe98931c3 100644 --- a/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx +++ b/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx @@ -5,13 +5,12 @@ import { Button } from "@/components/ui/button"; import { useGetCoreConfigQuery } from "@/lib/store"; import { Copy, InfoIcon, KeyRound } from "lucide-react"; import Link from "next/link"; -import { useMemo, useState } from "react"; +import { useMemo } from "react"; import { toast } from "sonner"; import ContactUsView from "../views/contactUsView"; export default function APIKeysView() { const { data: bifrostConfig, isLoading } = useGetCoreConfigQuery({ fromDB: true }); - const [isTokenVisible, setIsTokenVisible] = useState(false); const isAuthConfigure = useMemo(() => { return bifrostConfig?.auth_config?.is_enabled; }, [bifrostConfig]); diff --git a/ui/app/workspace/config/logging/page.tsx b/ui/app/workspace/config/logging/page.tsx index 4ccaddb31..a285754aa 100644 --- a/ui/app/workspace/config/logging/page.tsx +++ b/ui/app/workspace/config/logging/page.tsx @@ -1,12 +1,11 @@ -"use client" +"use client"; -import LoggingView from "../views/loggingView" +import LoggingView from "../views/loggingView"; export default function LoggingPage() { - return ( -
- -
- ) + return ( +
+ +
+ ); } - diff --git a/ui/app/workspace/config/mcp-gateway/page.tsx b/ui/app/workspace/config/mcp-gateway/page.tsx new file mode 100644 index 000000000..47f6865a9 --- /dev/null +++ b/ui/app/workspace/config/mcp-gateway/page.tsx @@ -0,0 +1,11 @@ +"use client"; + +import MCPGatewayView from "../views/mcpView"; + +export default function MCPGatewayPage() { + return ( +
+ +
+ ); +} diff --git a/ui/app/workspace/config/views/clientSettingsView.tsx b/ui/app/workspace/config/views/clientSettingsView.tsx index 0d90e2bdb..25d8e8d7f 100644 --- a/ui/app/workspace/config/views/clientSettingsView.tsx +++ b/ui/app/workspace/config/views/clientSettingsView.tsx @@ -21,6 +21,9 @@ const defaultConfig: CoreConfig = { max_request_body_size_mb: 100, enable_litellm_fallbacks: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function ClientSettingsView() { @@ -58,7 +61,11 @@ export default function ClientSettingsView() { const handleSave = useCallback(async () => { try { - await updateCoreConfig({ ...bifrostConfig!, client_config: localConfig }).unwrap(); + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); toast.success("Client settings updated successfully."); } catch (error) { toast.error(getErrorMessage(error)); @@ -66,7 +73,7 @@ export default function ClientSettingsView() { }, [bifrostConfig, localConfig, updateCoreConfig]); return ( -
+

Client Settings

diff --git a/ui/app/workspace/config/views/governanceView.tsx b/ui/app/workspace/config/views/governanceView.tsx index 6121b6692..e95441986 100644 --- a/ui/app/workspace/config/views/governanceView.tsx +++ b/ui/app/workspace/config/views/governanceView.tsx @@ -21,6 +21,9 @@ const defaultConfig: CoreConfig = { allowed_origins: [], max_request_body_size_mb: 100, enable_litellm_fallbacks: false, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function GovernanceView() { @@ -61,7 +64,7 @@ export default function GovernanceView() { }, [bifrostConfig, localConfig, updateCoreConfig]); return ( -
+

Governance

@@ -101,5 +104,3 @@ export default function GovernanceView() { const RestartWarning = () => { return
Need to restart Bifrost to apply changes.
; }; - - diff --git a/ui/app/workspace/config/views/loggingView.tsx b/ui/app/workspace/config/views/loggingView.tsx index df5beebf7..2f6f9ed2c 100644 --- a/ui/app/workspace/config/views/loggingView.tsx +++ b/ui/app/workspace/config/views/loggingView.tsx @@ -23,6 +23,9 @@ const defaultConfig: CoreConfig = { allowed_origins: [], max_request_body_size_mb: 100, enable_litellm_fallbacks: false, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function LoggingView() { @@ -126,7 +129,8 @@ export default function LoggingView() { Disable Content Logging

- When enabled, only usage metadata (latency, cost, token count, etc.) will be logged. Request/response content will not be stored. + When enabled, only usage metadata (latency, cost, token count, etc.) will be logged. Request/response content will not be + stored.

{ return
Need to restart Bifrost to apply changes.
; }; - - diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx new file mode 100644 index 000000000..9fa664c1a --- /dev/null +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -0,0 +1,219 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { CoreConfig } from "@/lib/types/config"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; + +const defaultConfig: CoreConfig = { + drop_excess_requests: false, + initial_pool_size: 1000, + prometheus_labels: [], + enable_logging: true, + enable_governance: true, + enforce_governance_header: false, + allow_direct_keys: false, + allowed_origins: [], + max_request_body_size_mb: 100, + enable_litellm_fallbacks: false, + disable_content_logging: false, + log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", +}; + +export default function MCPView() { + const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [localConfig, setLocalConfig] = useState(defaultConfig); + + const [localValues, setLocalValues] = useState<{ + mcp_agent_depth: string; + mcp_tool_execution_timeout: string; + mcp_code_mode_binding_level: string; + }>({ + mcp_agent_depth: "10", + mcp_tool_execution_timeout: "30", + mcp_code_mode_binding_level: "server", + }); + + useEffect(() => { + if (bifrostConfig && config) { + setLocalConfig(config); + setLocalValues({ + mcp_agent_depth: config?.mcp_agent_depth?.toString() || "10", + mcp_tool_execution_timeout: config?.mcp_tool_execution_timeout?.toString() || "30", + mcp_code_mode_binding_level: config?.mcp_code_mode_binding_level || "server", + }); + } + }, [config, bifrostConfig]); + + const hasChanges = useMemo(() => { + if (!config) return false; + return ( + localConfig.mcp_agent_depth !== config.mcp_agent_depth || + localConfig.mcp_tool_execution_timeout !== config.mcp_tool_execution_timeout || + localConfig.mcp_code_mode_binding_level !== (config.mcp_code_mode_binding_level || "server") + ); + }, [config, localConfig]); + + const handleAgentDepthChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_agent_depth: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_agent_depth: numValue })); + } + }, []); + + const handleToolExecutionTimeoutChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_tool_execution_timeout: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_tool_execution_timeout: numValue })); + } + }, []); + + const handleCodeModeBindingLevelChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_code_mode_binding_level: value })); + if (value === "server" || value === "tool") { + setLocalConfig((prev) => ({ ...prev, mcp_code_mode_binding_level: value })); + } + }, []); + + const handleSave = useCallback(async () => { + try { + const agentDepth = Number.parseInt(localValues.mcp_agent_depth); + const toolTimeout = Number.parseInt(localValues.mcp_tool_execution_timeout); + + if (isNaN(agentDepth) || agentDepth <= 0) { + toast.error("Max agent depth must be a positive number."); + return; + } + + if (isNaN(toolTimeout) || toolTimeout <= 0) { + toast.error("Tool execution timeout must be a positive number."); + return; + } + + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); + toast.success("MCP settings updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); + + return ( +
+
+
+

MCP Settings

+

Configure MCP (Model Context Protocol) agent and tool settings.

+
+ +
+
+ {/* Max Agent Depth */} +
+
+ +

Maximum depth for MCP agent execution.

+
+ handleAgentDepthChange(e.target.value)} + min="1" + /> +
+ + {/* Tool Execution Timeout */} +
+
+ +

Maximum time in seconds for tool execution.

+
+ handleToolExecutionTimeoutChange(e.target.value)} + min="1" + /> +
+ + {/* Code Mode Binding Level */} +
+
+ +

+ How tools are exposed in the VFS: server-level (all tools per server) or tool-level (individual tools). +

+
+ + + {/* Visual Example */} +
+

VFS Structure:

+ + {localValues.mcp_code_mode_binding_level === "server" ? ( +
+
+
servers/
+
├─ calculator.d.ts
+
├─ youtube.d.ts
+
└─ weather.d.ts
+
+

All tools per server in a single .d.ts file

+
+ ) : ( +
+
+
servers/
+
├─ calculator/
+
├─ add.d.ts
+
└─ subtract.d.ts
+
├─ youtube/
+
├─ GET_CHANNELS.d.ts
+
└─ SEARCH_VIDEOS.d.ts
+
└─ weather/
+
└─ get_forecast.d.ts
+
+

Individual .d.ts file for each tool

+
+ )} +
+
+
+
+ ); +} diff --git a/ui/app/workspace/config/views/observabilityView.tsx b/ui/app/workspace/config/views/observabilityView.tsx index 9f680b933..126da2a5d 100644 --- a/ui/app/workspace/config/views/observabilityView.tsx +++ b/ui/app/workspace/config/views/observabilityView.tsx @@ -24,6 +24,9 @@ const defaultConfig: CoreConfig = { enable_litellm_fallbacks: false, disable_content_logging: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function ObservabilityView() { diff --git a/ui/app/workspace/config/views/performanceTuningView.tsx b/ui/app/workspace/config/views/performanceTuningView.tsx index 2778c6e07..e6b605553 100644 --- a/ui/app/workspace/config/views/performanceTuningView.tsx +++ b/ui/app/workspace/config/views/performanceTuningView.tsx @@ -23,6 +23,9 @@ const defaultConfig: CoreConfig = { enable_litellm_fallbacks: false, disable_content_logging: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function PerformanceTuningView() { @@ -91,7 +94,11 @@ export default function PerformanceTuningView() { return; } - await updateCoreConfig({ ...bifrostConfig!, client_config: localConfig }).unwrap(); + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); toast.success("Performance settings updated successfully."); } catch (error) { toast.error(getErrorMessage(error)); @@ -99,7 +106,7 @@ export default function PerformanceTuningView() { }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); return ( -
+

Performance Tuning

diff --git a/ui/app/workspace/config/views/securityView.tsx b/ui/app/workspace/config/views/securityView.tsx index ace6c6d02..76b0b6163 100644 --- a/ui/app/workspace/config/views/securityView.tsx +++ b/ui/app/workspace/config/views/securityView.tsx @@ -31,6 +31,9 @@ const defaultConfig: CoreConfig = { max_request_body_size_mb: 100, enable_litellm_fallbacks: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function SecurityView() { diff --git a/ui/app/workspace/logs/views/columns.tsx b/ui/app/workspace/logs/views/columns.tsx index c69d275cb..e7fa0d1e4 100644 --- a/ui/app/workspace/logs/views/columns.tsx +++ b/ui/app/workspace/logs/views/columns.tsx @@ -1,13 +1,13 @@ -"use client" +"use client"; -import { Badge } from "@/components/ui/badge" -import { Button } from "@/components/ui/button" -import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons" -import { ProviderName, RequestTypeColors, RequestTypeLabels, Status, StatusColors } from "@/lib/constants/logs" -import { LogEntry, ResponsesMessageContentBlock } from "@/lib/types/logs" -import { ColumnDef } from "@tanstack/react-table" -import { ArrowUpDown, Trash2 } from "lucide-react" -import moment from "moment" +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; +import { ProviderName, RequestTypeColors, RequestTypeLabels, Status, StatusColors } from "@/lib/constants/logs"; +import { LogEntry, ResponsesMessageContentBlock } from "@/lib/types/logs"; +import { ColumnDef } from "@tanstack/react-table"; +import { ArrowUpDown, Trash2 } from "lucide-react"; +import moment from "moment"; function getMessage(log?: LogEntry) { if (log?.input_history && log.input_history.length > 0) { @@ -26,7 +26,8 @@ function getMessage(log?: LogEntry) { } return lastTextContentBlock; } else if (log?.responses_input_history && log.responses_input_history.length > 0) { - let lastMessageContent = log.responses_input_history[log.responses_input_history.length - 1].content; + let lastMessage = log.responses_input_history[log.responses_input_history.length - 1]; + let lastMessageContent = lastMessage.content; if (typeof lastMessageContent === "string") { return lastMessageContent; } @@ -36,7 +37,18 @@ function getMessage(log?: LogEntry) { lastTextContentBlock = block.text; } } - return lastTextContentBlock; + // If no content found in content field, check output field for Responses API + if (!lastTextContentBlock && lastMessage.output) { + // Handle output field - it could be a string, an array of content blocks, or a computer tool call output data + if (typeof lastMessage.output === "string") { + return lastMessage.output; + } else if (Array.isArray(lastMessage.output)) { + return lastMessage.output.map((block) => block.text).join("\n"); + } else if (lastMessage.output.type && lastMessage.output.type === "computer_screenshot") { + return lastMessage.output.image_url; + } + } + return lastTextContentBlock ?? ""; } else if (log?.speech_input) { return log.speech_input.input; } else if (log?.transcription_input) { @@ -174,12 +186,12 @@ export const createColumns = (onDelete: (log: LogEntry) => void, hasDeleteAccess { id: "actions", cell: ({ row }) => { - const log = row.original + const log = row.original; return ( - ) + ); }, }, -] +]; diff --git a/ui/app/workspace/logs/views/logChatMessageView.tsx b/ui/app/workspace/logs/views/logChatMessageView.tsx index 5b0c821cf..f99484ddf 100644 --- a/ui/app/workspace/logs/views/logChatMessageView.tsx +++ b/ui/app/workspace/logs/views/logChatMessageView.tsx @@ -111,14 +111,14 @@ export default function LogChatMessageView({ message, audioFormat }: LogChatMess options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.refusal}
+
{message.refusal}
)}
)} {/* Handle content */} {message.content && ( -
+
{typeof message.content === "string" ? ( <> {isJson(message.content) ? ( @@ -133,7 +133,7 @@ export default function LogChatMessageView({ message, audioFormat }: LogChatMess options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.content}
+
{message.content}
)} ) : ( diff --git a/ui/app/workspace/logs/views/logDetailsSheet.tsx b/ui/app/workspace/logs/views/logDetailsSheet.tsx index 57f673e5f..e8efb07ca 100644 --- a/ui/app/workspace/logs/views/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/views/logDetailsSheet.tsx @@ -1,14 +1,27 @@ "use client"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alertDialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { DottedSeparator } from "@/components/ui/separator"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; import { RequestTypeColors, RequestTypeLabels, Status, StatusColors } from "@/lib/constants/logs"; import { LogEntry } from "@/lib/types/logs"; -import { DollarSign, FileText, Timer, Trash2 } from "lucide-react"; +import { Clipboard, DollarSign, FileText, Timer, Trash2 } from "lucide-react"; import moment from "moment"; +import { toast } from "sonner"; import { CodeEditor } from "./codeEditor"; import LogChatMessageView from "./logChatMessageView"; import LogEntryDetailsView from "./logEntryDetailsView"; @@ -34,6 +47,121 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet } catch (ignored) {} } + const copyRequestBody = async () => { + try { + // Check if request is for responses, chat, speech, text completion, or embedding (exclude transcriptions) + const object = log.object?.toLowerCase() || ""; + const isChat = object === "chat_completion" || object === "chat_completion_stream"; + const isResponses = object === "responses" || object === "responses_stream"; + const isSpeech = object === "speech" || object === "speech_stream"; + const isTextCompletion = object === "text_completion" || object === "text_completion_stream"; + const isEmbedding = object === "embedding"; + const isTranscription = object === "transcription" || object === "transcription_stream"; + + // Skip if transcription + if (isTranscription) { + toast.error("Copy request body is not available for transcription requests"); + return; + } + + // Skip if not a supported request type + if (!isChat && !isResponses && !isSpeech && !isTextCompletion && !isEmbedding) { + toast.error("Copy request body is only available for chat, responses, speech, text completion, and embedding requests"); + return; + } + + // Helper function to extract text content from ChatMessage + const extractTextFromMessage = (message: any): string => { + if (!message || !message.content) { + return ""; + } + if (typeof message.content === "string") { + return message.content; + } + if (Array.isArray(message.content)) { + return message.content + .filter((block: any) => block && block.type === "text" && block.text) + .map((block: any) => block.text || "") + .join(""); + } + return ""; + }; + + // Helper function to extract texts from ChatMessage content blocks (for embeddings) + const extractTextsFromMessage = (message: any): string[] => { + if (!message || !message.content) { + return []; + } + if (typeof message.content === "string") { + return message.content ? [message.content] : []; + } + if (Array.isArray(message.content)) { + return message.content.filter((block: any) => block && block.type === "text" && block.text).map((block: any) => block.text); + } + return []; + }; + + // Build request body following OpenAI schema + const requestBody: any = { + model: log.provider && log.model ? `${log.provider}/${log.model}` : log.model || "", + }; + + // Add messages/input/prompt based on request type + if (isChat && log.input_history && log.input_history.length > 0) { + requestBody.messages = log.input_history; + } else if (isResponses && log.responses_input_history && log.responses_input_history.length > 0) { + requestBody.input = log.responses_input_history; + } else if (isSpeech && log.speech_input) { + requestBody.input = log.speech_input.input; + } else if (isTextCompletion && log.input_history && log.input_history.length > 0) { + // For text completions, extract prompt from input_history + const firstMessage = log.input_history[0]; + const prompt = extractTextFromMessage(firstMessage); + if (prompt) { + requestBody.prompt = prompt; + } + } else if (isEmbedding && log.input_history && log.input_history.length > 0) { + // For embeddings, extract all texts from input_history + const texts: string[] = []; + for (const message of log.input_history) { + const messageTexts = extractTextsFromMessage(message); + texts.push(...messageTexts); + } + if (texts.length > 0) { + // Use single string if only one text, otherwise use array + requestBody.input = texts.length === 1 ? texts[0] : texts; + } + } + + // Add params (excluding tools and instructions as they're handled separately in OpenAI schema) + if (log.params) { + const paramsCopy = { ...log.params }; + // Remove tools and instructions from params as they're typically top-level in OpenAI schema + // Keep all other params (temperature, max_tokens, voice, etc.) + delete paramsCopy.tools; + delete paramsCopy.instructions; + + // Merge remaining params into request body + Object.assign(requestBody, paramsCopy); + } + + // Add tools if they exist (for chat and responses) - OpenAI schema has tools at top level + if ((isChat || isResponses) && log.params?.tools && Array.isArray(log.params.tools) && log.params.tools.length > 0) { + requestBody.tools = log.params.tools; + } + + // Add instructions if they exist (for responses) - OpenAI schema has instructions at top level + if (isResponses && log.params?.instructions) { + requestBody.instructions = log.params.instructions; + } + + const requestBodyJson = JSON.stringify(requestBody, null, 2); + navigator.clipboard.writeText(requestBodyJson); + toast.success("Request body copied to clipboard"); + } catch (error) { + toast.error("Failed to copy request body"); + } + }; // Extract audio format from request params // Format can be in params.audio?.format or params.extra_params?.audio?.format const audioFormat = (log.params as any)?.audio?.format || (log.params as any)?.extra_params?.audio?.format || undefined; @@ -50,16 +178,44 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
- +
+ + + + + + +

Copy request body JSON

+
+
+
+ + + + + + + Are you sure you want to delete this log? + This action cannot be undone. This will permanently delete the log entry. + + + Cancel + { + handleDelete(log); + onOpenChange(false); + }} + > + Delete + + + + +
@@ -338,7 +494,7 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
{toolsParameter && (
-
Tools
+
Tools ({log.params?.tools?.length || 0})
= ({ open, onClose, onSaved }) => { } }, [open]); - const handleChange = (field: keyof CreateMCPClientRequest, value: string | string[] | MCPConnectionType | MCPStdioConfig | undefined) => { + const handleChange = ( + field: keyof CreateMCPClientRequest, + value: string | string[] | boolean | MCPConnectionType | MCPStdioConfig | undefined, + ) => { setForm((prev) => ({ ...prev, [field]: value })); }; @@ -95,10 +100,13 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { const validator = new Validator([ // Name validation - Validator.required(form.name?.trim(), "Client name is required"), - Validator.pattern(form.name || "", /^[a-zA-Z0-9-_]+$/, "Client name can only contain letters, numbers, hyphens and underscores"), - Validator.minLength(form.name || "", 3, "Client name must be at least 3 characters"), - Validator.maxLength(form.name || "", 50, "Client name cannot exceed 50 characters"), + Validator.required(form.name?.trim(), "Server name is required"), + Validator.pattern(form.name || "", /^[a-zA-Z0-9_]+$/, "Server name can only contain letters, numbers, and underscores"), + Validator.custom(!(form.name || "").includes("-"), "Server name cannot contain hyphens"), + Validator.custom(!(form.name || "").includes(" "), "Server name cannot contain spaces"), + Validator.custom((form.name || "").length === 0 || !/^[0-9]/.test(form.name || ""), "Server name cannot start with a number"), + Validator.minLength(form.name || "", 3, "Server name must be at least 3 characters"), + Validator.maxLength(form.name || "", 50, "Server name cannot exceed 50 characters"), // Connection type specific validation ...(form.connection_type === "http" || form.connection_type === "sse" @@ -156,7 +164,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { setIsLoading(false); toast({ title: "Success", - description: "Client created", + description: "Server created", }); onSaved(); onClose(); @@ -170,7 +178,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { - New MCP Client + New MCP Server
@@ -178,7 +186,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { ) => handleChange("name", e.target.value)} - placeholder="Client name" + placeholder="Server name" maxLength={50} />
@@ -197,6 +205,15 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => {
+
+ + handleChange("is_code_mode_client", checked)} + /> +
+ {(form.connection_type === "http" || form.connection_type === "sse") && ( <>
@@ -275,7 +292,11 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { - diff --git a/ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx b/ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx similarity index 62% rename from ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx rename to ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx index 618e45e88..52dcadb9c 100644 --- a/ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx +++ b/ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx @@ -8,6 +8,7 @@ import { HeadersTable } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; import { Sheet, SheetContent, SheetDescription, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { Switch } from "@/components/ui/switch"; +import { TriStateCheckbox } from "@/components/ui/tristateCheckbox"; import { useToast } from "@/hooks/use-toast"; import { MCP_STATUS_COLORS } from "@/lib/constants/config"; import { getErrorMessage, useUpdateMCPClientMutation } from "@/lib/store"; @@ -34,8 +35,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: mode: "onBlur", defaultValues: { name: mcpClient.config.name, + is_code_mode_client: mcpClient.config.is_code_mode_client || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], + tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], }, }); @@ -43,8 +46,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: useEffect(() => { form.reset({ name: mcpClient.config.name, + is_code_mode_client: mcpClient.config.is_code_mode_client || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], + tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], }); }, [form, mcpClient]); @@ -54,8 +59,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: id: mcpClient.config.id, data: { name: data.name, + is_code_mode_client: data.is_code_mode_client, headers: data.headers, tools_to_execute: data.tools_to_execute, + tools_to_auto_execute: data.tools_to_auto_execute, }, }).unwrap(); @@ -106,6 +113,69 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: } form.setValue("tools_to_execute", newTools, { shouldDirty: true }); + + // If tool is being removed from tools_to_execute, also remove it from tools_to_auto_execute + if (!checked) { + const currentAutoExecute = form.getValues("tools_to_auto_execute") || []; + if (currentAutoExecute.includes(toolName) || currentAutoExecute.includes("*")) { + const newAutoExecute = currentAutoExecute.filter((tool) => tool !== toolName); + // If we had "*" and removed a tool, we need to recalculate + if (currentAutoExecute.includes("*")) { + // If all tools mode, keep "*" only if tool is still in tools_to_execute + if (newTools.includes("*")) { + form.setValue("tools_to_auto_execute", ["*"], { shouldDirty: true }); + } else { + // Switch to explicit list - when in wildcard mode, all remaining tools should be auto-execute + form.setValue("tools_to_auto_execute", newTools, { shouldDirty: true }); + } + } else { + form.setValue("tools_to_auto_execute", newAutoExecute, { shouldDirty: true }); + } + } + } + }; + + const handleAutoExecuteToggle = (toolName: string, checked: boolean) => { + const currentAutoExecute = form.getValues("tools_to_auto_execute") || []; + const currentTools = form.getValues("tools_to_execute") || []; + const allToolNames = mcpClient.tools?.map((tool) => tool.name) || []; + + // Check if we're in "all tools" mode (wildcard) + const isAllToolsMode = currentTools.includes("*"); + const isAllAutoExecuteMode = currentAutoExecute.includes("*"); + + let newAutoExecute: string[]; + + if (isAllAutoExecuteMode) { + if (checked) { + // Already all selected, keep wildcard + newAutoExecute = ["*"]; + } else { + // Unchecking a tool when all are selected - switch to explicit list without this tool + if (isAllToolsMode) { + newAutoExecute = allToolNames.filter((name) => name !== toolName); + } else { + newAutoExecute = currentTools.filter((name) => name !== toolName); + } + } + } else { + // We're in explicit tool selection mode + if (checked) { + // Add tool to selection + newAutoExecute = currentAutoExecute.includes(toolName) ? currentAutoExecute : [...currentAutoExecute, toolName]; + + // If we now have all allowed tools selected, switch to wildcard mode + const allowedTools = isAllToolsMode ? allToolNames : currentTools; + if (newAutoExecute.length === allowedTools.length && allowedTools.every((tool) => newAutoExecute.includes(tool))) { + newAutoExecute = ["*"]; + } + } else { + // Remove tool from selection + newAutoExecute = currentAutoExecute.filter((tool) => tool !== toolName); + } + } + + form.setValue("tools_to_auto_execute", newAutoExecute, { shouldDirty: true }); }; return ( @@ -113,14 +183,14 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }:
- -
+ +
{mcpClient.config.name} {mcpClient.state} - MCP client configuration and available tools + MCP server configuration and available tools
- Manage clients that can connect to the MCP Tools endpoint. + Manage servers that can connect to the MCP Tools endpoint.
@@ -144,7 +144,10 @@ export default function MCPClientsTable({ mcpClients }: MCPClientsTableProps) { Name Connection Type + Code Mode Connection Info + Enabled Tools + Auto-execute Tools State @@ -152,56 +155,101 @@ export default function MCPClientsTable({ mcpClients }: MCPClientsTableProps) { {clients.length === 0 && ( - + No clients found. )} - {clients.map((c: MCPClient) => ( - handleRowClick(c)}> - {c.config.name} - {getConnectionTypeDisplay(c.config.connection_type)} - {getConnectionDisplay(c)} - - {c.state} - - e.stopPropagation()}> - - - - - - - - - Remove MCP Client - - Are you sure you want to remove MCP client {c.config.name}? You will need to reconnect the client to continue - using it. - - - - Cancel - handleDelete(c)}>Delete - - - - - - ))} + + + {c.state == "connected" ? ( + <> + {autoExecuteToolsCount}/{c.tools?.length} + + ) : ( + "-" + )} + + + {c.state} + + e.stopPropagation()}> + + + + + + + + + Remove MCP Server + + Are you sure you want to remove MCP server {c.config.name}? You will need to reconnect the server to continue + using it. + + + + Cancel + handleDelete(c)}>Delete + + + + + + ); + })}
diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index 6e0d626ba..c4393a441 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -316,7 +316,7 @@ export default function AppSidebar() { const hasLogsAccess = useRbac(RbacResource.Logs, RbacOperation.View); const hasObservabilityAccess = useRbac(RbacResource.Observability, RbacOperation.View); const hasModelProvidersAccess = useRbac(RbacResource.ModelProvider, RbacOperation.View); - const hasMCPToolsAccess = useRbac(RbacResource.MCPGateway, RbacOperation.View); + const hasMCPGatewayAccess = useRbac(RbacResource.MCPGateway, RbacOperation.View); const hasPluginsAccess = useRbac(RbacResource.Plugins, RbacOperation.View); const hasUserProvisioningAccess = useRbac(RbacResource.UserProvisioning, RbacOperation.View); const hasAuditLogsAccess = useRbac(RbacResource.AuditLogs, RbacOperation.View); @@ -369,11 +369,11 @@ export default function AppSidebar() { hasAccess: hasModelProvidersAccess, }, { - title: "MCP Tools", - url: "/workspace/mcp-clients", + title: "MCP Gateway", + url: "/workspace/mcp-gateway", icon: MCPIcon, description: "MCP configuration", - hasAccess: hasMCPToolsAccess, + hasAccess: hasMCPGatewayAccess, }, { title: "Plugins", @@ -486,6 +486,13 @@ export default function AppSidebar() { description: "Client configuration settings", hasAccess: hasSettingsAccess, }, + { + title: "MCP Gateway", + url: "/workspace/config/mcp-gateway", + icon: MCPIcon, + description: "MCP gateway configuration", + hasAccess: hasMCPGatewayAccess, + }, { title: "Pricing Config", url: "/workspace/config/pricing-config", @@ -715,7 +722,7 @@ export default function AppSidebar() { isExpanded={expandedItems.has(item.title)} onToggle={() => toggleItem(item.title)} pathname={pathname} - router={router} + router={router} /> ); })} diff --git a/ui/components/ui/sheet.tsx b/ui/components/ui/sheet.tsx index 6b68509f7..63b6fb33c 100644 --- a/ui/components/ui/sheet.tsx +++ b/ui/components/ui/sheet.tsx @@ -72,11 +72,11 @@ function SheetContent({ className={cn( "bg-background data-[state=open]:animate-in data-[state=closed]:animate-out fixed z-50 flex flex-col shadow-lg transition-all ease-in-out data-[state=closed]:duration-100 data-[state=open]:duration-100", side === "right" && - "data-[state=closed]:slide-out-to-right data-[state=open]:slide-in-from-right top-2 bottom-2 right-0 h-auto w-3/4 border-l rounded-l-lg", + "data-[state=closed]:slide-out-to-right data-[state=open]:slide-in-from-right top-2 right-0 bottom-2 h-auto w-3/4 rounded-l-lg border-l", side === "right" && (!expandable || !expanded) && "sm:max-w-2xl", side === "right" && expandable && expanded && "sm:max-w-5xl", side === "left" && - "data-[state=closed]:slide-out-to-left data-[state=open]:slide-in-from-left top-2 bottom-2 left-0 h-auto w-3/4 border-r rounded-r-lg sm:max-w-sm", + "data-[state=closed]:slide-out-to-left data-[state=open]:slide-in-from-left top-2 bottom-2 left-0 h-auto w-3/4 rounded-r-lg border-r sm:max-w-sm", side === "top" && "data-[state=closed]:slide-out-to-top data-[state=open]:slide-in-from-top inset-x-0 top-0 h-auto border-b", side === "bottom" && "data-[state=closed]:slide-out-to-bottom data-[state=open]:slide-in-from-bottom inset-x-0 bottom-0 h-auto border-t", @@ -91,36 +91,33 @@ function SheetContent({ ); } -function SheetHeader({ className, children, ...props }: React.ComponentProps<"div">) { +function SheetHeader({ + className, + children, + showCloseButton = true, + ...props +}: React.ComponentProps<"div"> & { showCloseButton?: boolean }) { const sheetContext = useSheetContext(); return ( -
+
{sheetContext?.expandable && sheetContext?.side === "right" && ( )} -
- {children} -
- - - Close - +
{children}
+ {showCloseButton && ( + + + Close + + )}
); } @@ -138,4 +135,3 @@ function SheetDescription({ className, ...props }: React.ComponentProps void; + + /** Optional label to render to the right of the checkbox */ + label?: React.ReactNode; + + /** Optional disabled state */ + disabled?: boolean; + + /** Extra tailwind classes for the wrapper */ + className?: string; + + /** Accessible name for icon-only checkbox (e.g. when label is rendered elsewhere) */ + ariaLabel?: string; +} + +export const TriStateCheckbox: React.FC = ({ + allIds, + selectedIds, + onChange, + label, + disabled = false, + className = "", + ariaLabel, +}) => { + const state: TriState = useMemo(() => { + if (!allIds.length) return "none"; + + const selectedSet = new Set(selectedIds); + const selectedCount = allIds.filter((id) => selectedSet.has(id)).length; + + if (selectedCount === 0) return "none"; + if (selectedCount === allIds.length) return "all"; + return "some"; + }, [allIds, selectedIds]); + + const handleClick = () => { + if (disabled) return; + + let nextSelected: string[]; + + switch (state) { + case "all": + // clear all + nextSelected = []; + break; + case "some": + case "none": + default: + // select all + nextSelected = [...allIds]; + break; + } + + onChange(nextSelected); + }; + + const ariaChecked: boolean | "mixed" = state === "all" ? true : state === "none" ? false : "mixed"; + + const isChecked = state === "all"; + const isIndeterminate = state === "some"; + + return ( + + ); +}; diff --git a/ui/lib/constants/logs.ts b/ui/lib/constants/logs.ts index 98a4b6390..2b60d8485 100644 --- a/ui/lib/constants/logs.ts +++ b/ui/lib/constants/logs.ts @@ -119,8 +119,6 @@ export const RequestTypeLabels = { file_retrieve: "File Retrieve", file_delete: "File Delete", file_content: "File Content", - - } as const; export const RequestTypeColors = { @@ -155,7 +153,7 @@ export const RequestTypeColors = { batch_retrieve: "bg-red-100 text-red-800", batch_cancel: "bg-yellow-100 text-yellow-800", batch_results: "bg-purple-100 text-purple-800", - + file_upload: "bg-pink-100 text-pink-800", file_list: "bg-lime-100 text-lime-800", file_retrieve: "bg-orange-100 text-orange-800", diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 3c1f082be..c944d66dc 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -325,6 +325,9 @@ export interface CoreConfig { allowed_origins: string[]; max_request_body_size_mb: number; enable_litellm_fallbacks: boolean; + mcp_agent_depth: number; + mcp_tool_execution_timeout: number; + mcp_code_mode_binding_level?: string; } // Semantic cache configuration types diff --git a/ui/lib/types/logs.ts b/ui/lib/types/logs.ts index 35d9a0b6d..024979ce1 100644 --- a/ui/lib/types/logs.ts +++ b/ui/lib/types/logs.ts @@ -451,6 +451,13 @@ export interface ResponsesMessage { encrypted_content?: string; // Additional tool-specific fields [key: string]: any; + output?: string | ResponsesMessageContentBlock[] | ResponsesComputerToolCallOutputData; +} + +export interface ResponsesComputerToolCallOutputData { + type: "computer_screenshot"; + file_id?: string; + image_url?: string; } // Stream options for responses diff --git a/ui/lib/types/mcp.ts b/ui/lib/types/mcp.ts index 7b7f4f8fb..9faee1dbb 100644 --- a/ui/lib/types/mcp.ts +++ b/ui/lib/types/mcp.ts @@ -13,10 +13,12 @@ export interface MCPStdioConfig { export interface MCPClientConfig { id: string; name: string; + is_code_mode_client?: boolean; connection_type: MCPConnectionType; connection_string?: string; stdio_config?: MCPStdioConfig; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; headers?: Record; } @@ -28,15 +30,19 @@ export interface MCPClient { export interface CreateMCPClientRequest { name: string; + is_code_mode_client?: boolean; connection_type: MCPConnectionType; connection_string?: string; stdio_config?: MCPStdioConfig; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; headers?: Record; } export interface UpdateMCPClientRequest { name?: string; + is_code_mode_client?: boolean; headers?: Record; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 64418aa2e..be549bb27 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -91,11 +91,11 @@ export const s3BucketConfigSchema = z.object({ bucket_name: z.string().min(1, "Bucket name is required"), prefix: z.string().optional(), is_default: z.boolean().optional(), -}) +}); export const batchS3ConfigSchema = z.object({ buckets: z.array(s3BucketConfigSchema).optional(), -}) +}); // Bedrock key config schema export const bedrockKeyConfigSchema = z @@ -461,6 +461,9 @@ export const coreConfigSchema = z.object({ allow_direct_keys: z.boolean().default(false), allowed_origins: z.array(z.string()).default(["*"]), max_request_body_size_mb: z.number().min(1).default(100), + mcp_agent_depth: z.number().min(1).default(10), + mcp_tool_execution_timeout: z.number().min(1).default(30), + mcp_code_mode_binding_level: z.enum(["server", "tool"]).default("server"), }); // Bifrost config schema @@ -599,7 +602,13 @@ export const maximFormSchema = z.object({ // MCP Client update schema export const mcpClientUpdateSchema = z.object({ - name: z.string().min(1, "Name is required"), + is_code_mode_client: z.boolean().optional(), + name: z + .string() + .min(1, "Name is required") + .refine((val) => !val.includes("-"), { message: "Client name cannot contain hyphens" }) + .refine((val) => !val.includes(" "), { message: "Client name cannot contain spaces" }) + .refine((val) => !/^[0-9]/.test(val), { message: "Client name cannot start with a number" }), headers: z.record(z.string(), z.string()).optional(), tools_to_execute: z .array(z.string()) @@ -619,10 +628,28 @@ export const mcpClientUpdateSchema = z.object({ }, { message: "Duplicate tool names are not allowed" }, ), + tools_to_auto_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return !hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), }); // Global proxy type schema -export const globalProxyTypeSchema = z.enum(['http', 'socks5', 'tcp']); +export const globalProxyTypeSchema = z.enum(["http", "socks5", "tcp"]); // Global proxy configuration schema export const globalProxyConfigSchema = z @@ -649,8 +676,8 @@ export const globalProxyConfigSchema = z return true; }, { - message: 'Proxy URL is required when proxy is enabled', - path: ['url'], + message: "Proxy URL is required when proxy is enabled", + path: ["url"], }, ) .refine( @@ -667,8 +694,8 @@ export const globalProxyConfigSchema = z return true; }, { - message: 'Must be a valid URL (e.g., http://proxy.example.com:8080)', - path: ['url'], + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], }, ); diff --git a/ui/lib/utils/validation.ts b/ui/lib/utils/validation.ts index 7b77cb378..aed9c326e 100644 --- a/ui/lib/utils/validation.ts +++ b/ui/lib/utils/validation.ts @@ -371,7 +371,11 @@ function isValidWildcardOrigin(origin: string): boolean { * @returns Object with validation result and invalid origins */ export function validateOrigins(origins: string[]): { isValid: boolean; invalidOrigins: string[] } { - const invalidOrigins = origins?.filter((origin) => !isValidOrigin(origin)) || []; + if (!origins || origins.length === 0) { + return { isValid: true, invalidOrigins: [] }; + } + + const invalidOrigins = origins.filter((origin) => !isValidOrigin(origin)); return { isValid: invalidOrigins.length === 0,