diff --git a/core/bifrost.go b/core/bifrost.go
index 26e8e5e5e..529dd458b 100644
--- a/core/bifrost.go
+++ b/core/bifrost.go
@@ -1800,15 +1800,10 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) {
 			return tools[i].Name < tools[j].Name
 		})
 
-		state := schemas.MCPConnectionStateConnected
-		if client.Conn == nil {
-			state = schemas.MCPConnectionStateDisconnected
-		}
-
 		clientsInConfig = append(clientsInConfig, schemas.MCPClient{
 			Config: client.ExecutionConfig,
 			Tools:  tools,
-			State:  state,
+			State:  client.State,
 		})
 	}
@@ -1937,6 +1932,7 @@ func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecution
 	return nil
 }
 
+
 // PROVIDER MANAGEMENT
 
 // createBaseProvider creates a provider based on the base provider type
diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go
index b0456c5f9..fe7392abb 100644
--- a/core/mcp/clientmanager.go
+++ b/core/mcp/clientmanager.go
@@ -137,7 +137,10 @@ func (m *MCPManager) removeClientUnsafe(id string) error {
 		return fmt.Errorf("client %s not found", id)
 	}
 
-	logger.Info(fmt.Sprintf("%s Disconnecting MCP client: %s", MCPLogPrefix, client.ExecutionConfig.Name))
+	logger.Info(fmt.Sprintf("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name))
+
+	// Stop health monitoring for this client
+	m.healthMonitorManager.StopMonitoring(id)
 
 	// Cancel SSE context if present (required for proper SSE cleanup)
 	if client.CancelFunc != nil {
@@ -149,7 +152,7 @@ func (m *MCPManager) removeClientUnsafe(id string) error {
 	// This handles cleanup for all transport types (HTTP, STDIO, SSE)
 	if client.Conn != nil {
 		if err := client.Conn.Close(); err != nil {
-			logger.Error("%s Failed to close MCP client %s: %v", MCPLogPrefix, client.ExecutionConfig.Name, err)
+			logger.Error("%s Failed to close MCP server '%s': %v", MCPLogPrefix, client.ExecutionConfig.Name, err)
 		}
 		client.Conn = nil
 	}
@@ -400,6 +403,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
 		// Store the external client connection and details
 		client.Conn = externalClient
 		client.ConnectionInfo = connectionInfo
+		client.State = schemas.MCPConnectionStateConnected
 
 		// Store cancel function for SSE connections to enable proper cleanup
 		if config.ConnectionType == schemas.MCPConnectionTypeSSE {
@@ -411,7 +415,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
 			client.ToolMap[toolName] = tool
 		}
 
-		logger.Info(fmt.Sprintf("%s Connected to MCP client: %s", MCPLogPrefix, config.Name))
+		logger.Info(fmt.Sprintf("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name))
 	} else {
 		// Clean up resources before returning error: client was removed during connection setup
 		// Cancel SSE context if it was created
@@ -427,6 +431,23 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
 		return fmt.Errorf("client %s was removed during connection setup", config.Name)
 	}
 
+	// Register OnConnectionLost hook for SSE connections to detect idle timeouts
+	if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil {
+		externalClient.OnConnectionLost(func(err error) {
+			logger.Warn(fmt.Sprintf("%s SSE connection lost for MCP server '%s': %v", MCPLogPrefix, config.Name, err))
+			// Update state to disconnected
+			m.mu.Lock()
+			if client, exists := m.clientMap[config.ID]; exists {
+				client.State = schemas.MCPConnectionStateDisconnected
+			}
+			m.mu.Unlock()
+		})
+	}
+
+	// Start health monitoring for the client
+	monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval)
+	m.healthMonitorManager.StartMonitoring(monitor)
+
 	return nil
 }
diff --git a/core/mcp/health_monitor.go b/core/mcp/health_monitor.go
new file mode 100644
index 000000000..6a55938fe
--- /dev/null
+++ b/core/mcp/health_monitor.go
@@ -0,0 +1,231 @@
+package mcp
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/maximhq/bifrost/core/schemas"
+)
+
+const (
+	// Health check configuration
+	DefaultHealthCheckInterval = 10 * time.Second // Interval between health checks
+	DefaultHealthCheckTimeout  = 5 * time.Second  // Timeout for each health check
+	MaxConsecutiveFailures     = 5                // Number of failures before marking as unhealthy
+)
+
+// ClientHealthMonitor tracks the health status of an MCP client
+type ClientHealthMonitor struct {
+	manager                *MCPManager
+	clientID               string
+	interval               time.Duration
+	timeout                time.Duration
+	maxConsecutiveFailures int
+	mu                     sync.Mutex
+	ticker                 *time.Ticker
+	ctx                    context.Context
+	cancel                 context.CancelFunc
+	isMonitoring           bool
+	consecutiveFailures    int
+}
+
+// NewClientHealthMonitor creates a new health monitor for an MCP client
+func NewClientHealthMonitor(
+	manager *MCPManager,
+	clientID string,
+	interval time.Duration,
+) *ClientHealthMonitor {
+	if interval == 0 {
+		interval = DefaultHealthCheckInterval
+	}
+
+	return &ClientHealthMonitor{
+		manager:                manager,
+		clientID:               clientID,
+		interval:               interval,
+		timeout:                DefaultHealthCheckTimeout,
+		maxConsecutiveFailures: MaxConsecutiveFailures,
+		isMonitoring:           false,
+		consecutiveFailures:    0,
+	}
+}
+
+// Start begins monitoring the client's health in a background goroutine
+func (chm *ClientHealthMonitor) Start() {
+	chm.mu.Lock()
+	defer chm.mu.Unlock()
+
+	if chm.isMonitoring {
+		return // Already monitoring
+	}
+
+	chm.isMonitoring = true
+	chm.ctx, chm.cancel = context.WithCancel(context.Background())
+	chm.ticker = time.NewTicker(chm.interval)
+
+	go chm.monitorLoop()
+	logger.Debug(fmt.Sprintf("%s Health monitor started for client %s (interval: %v)", MCPLogPrefix, chm.clientID, chm.interval))
+}
+
+// Stop stops monitoring the client's health
+func (chm *ClientHealthMonitor) Stop() {
+	chm.mu.Lock()
+	defer chm.mu.Unlock()
+
+	if !chm.isMonitoring {
+		return // Not monitoring
+	}
+
+	chm.isMonitoring = false
+	if chm.ticker != nil {
+		chm.ticker.Stop()
+	}
+	if chm.cancel != nil {
+		chm.cancel()
+	}
+	logger.Debug(fmt.Sprintf("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID))
+}
+
+// monitorLoop runs the health check loop
+func (chm *ClientHealthMonitor) monitorLoop() {
+	for {
+		select {
+		case <-chm.ctx.Done():
+			return
+		case <-chm.ticker.C:
+			chm.performHealthCheck()
+		}
+	}
+}
+
+// performHealthCheck performs a health check on the client
+func (chm *ClientHealthMonitor) performHealthCheck() {
+	// Get the client connection
+	chm.manager.mu.RLock()
+	clientState, exists := chm.manager.clientMap[chm.clientID]
+	chm.manager.mu.RUnlock()
+
+	if !exists {
+		chm.Stop()
+		return
+	}
+
+	if clientState.Conn == nil {
+		// Client not connected, mark as disconnected
+		chm.updateClientState(schemas.MCPConnectionStateDisconnected)
+		chm.incrementFailures()
+		return
+	}
+
+	// Perform ping with timeout
+	ctx, cancel := context.WithTimeout(context.Background(), chm.timeout)
+	defer cancel()
+
+	err := clientState.Conn.Ping(ctx)
+	if err != nil {
+		chm.incrementFailures()
+
+		// After max consecutive failures, mark as disconnected
+		if chm.getConsecutiveFailures() >= chm.maxConsecutiveFailures {
+			chm.updateClientState(schemas.MCPConnectionStateDisconnected)
+		}
+	} else {
+		// Health check passed
+		chm.resetFailures()
+		chm.updateClientState(schemas.MCPConnectionStateConnected)
+	}
+}
+
+// updateClientState updates the client's connection state
+func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionState) {
+	chm.manager.mu.Lock()
+	clientState, exists := chm.manager.clientMap[chm.clientID]
+	if !exists {
+		chm.manager.mu.Unlock()
+		return
+	}
+
+	// Only update if state changed
+	stateChanged := clientState.State != state
+	if stateChanged {
+		clientState.State = state
+	}
+	chm.manager.mu.Unlock()
+
+	// Log after releasing the lock
+	if stateChanged {
+		logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, chm.clientID, state))
+	}
+}
+
+// incrementFailures increments the consecutive failure counter
+func (chm *ClientHealthMonitor) incrementFailures() {
+	chm.mu.Lock()
+	defer chm.mu.Unlock()
+	chm.consecutiveFailures++
+}
+
+// resetFailures resets the consecutive failure counter
+func (chm *ClientHealthMonitor) resetFailures() {
+	chm.mu.Lock()
+	defer chm.mu.Unlock()
+	chm.consecutiveFailures = 0
+}
+
+// getConsecutiveFailures returns the current consecutive failure count
+func (chm *ClientHealthMonitor) getConsecutiveFailures() int {
+	chm.mu.Lock()
+	defer chm.mu.Unlock()
+	return chm.consecutiveFailures
+}
+
+// HealthMonitorManager manages all client health monitors
+type HealthMonitorManager struct {
+	monitors map[string]*ClientHealthMonitor
+	mu       sync.RWMutex
+}
+
+// NewHealthMonitorManager creates a new health monitor manager
+func NewHealthMonitorManager() *HealthMonitorManager {
+	return &HealthMonitorManager{
+		monitors: make(map[string]*ClientHealthMonitor),
+	}
+}
+
+// StartMonitoring starts monitoring a specific client
+func (hmm *HealthMonitorManager) StartMonitoring(monitor *ClientHealthMonitor) {
+	hmm.mu.Lock()
+	defer hmm.mu.Unlock()
+
+	// Stop any existing monitor for this client
+	if existing, ok := hmm.monitors[monitor.clientID]; ok {
+		existing.Stop()
+	}
+
+	hmm.monitors[monitor.clientID] = monitor
+	monitor.Start()
+}
+
+// StopMonitoring stops monitoring a specific client
+func (hmm *HealthMonitorManager) StopMonitoring(clientID string) {
+	hmm.mu.Lock()
+	defer hmm.mu.Unlock()
+
+	if monitor, ok := hmm.monitors[clientID]; ok {
+		monitor.Stop()
+		delete(hmm.monitors, clientID)
+	}
+}
+
+// StopAll stops all monitoring
+func (hmm *HealthMonitorManager) StopAll() {
+	hmm.mu.Lock()
+	defer hmm.mu.Unlock()
+
+	for _, monitor := range hmm.monitors {
+		monitor.Stop()
+	}
+	hmm.monitors = make(map[string]*ClientHealthMonitor)
+}
diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go
index 01b10b0cb..c7ec34d25 100644
--- a/core/mcp/mcp.go
+++ b/core/mcp/mcp.go
@@ -39,12 +39,13 @@ const (
 // It provides a bridge between Bifrost and various MCP servers, supporting
 // both local tool hosting and external MCP server connections.
 type MCPManager struct {
-	ctx           context.Context
-	toolsHandler  *ToolsManager                      // Handler for MCP tools
-	server        *server.MCPServer                  // Local MCP server instance for hosting tools (STDIO-based)
-	clientMap     map[string]*schemas.MCPClientState // Map of MCP client names to their configurations
-	mu            sync.RWMutex                       // Read-write mutex for thread-safe operations
-	serverRunning bool                               // Track whether local MCP server is running
+	ctx                  context.Context
+	toolsManager         *ToolsManager                      // Handler for MCP tools
+	server               *server.MCPServer                  // Local MCP server instance for hosting tools (STDIO-based)
+	clientMap            map[string]*schemas.MCPClientState // Map of MCP client names to their configurations
+	mu                   sync.RWMutex                       // Read-write mutex for thread-safe operations
+	serverRunning        bool                               // Track whether local MCP server is running
+	healthMonitorManager *HealthMonitorManager              // Manager for client health monitors
 }
 
 // MCPToolFunction is a generic function type for handling tool calls with typed arguments.
@@ -75,10 +76,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas
 	}
 	// Creating new instance
 	manager := &MCPManager{
-		ctx:       ctx,
-		clientMap: make(map[string]*schemas.MCPClientState),
+		ctx:                  ctx,
+		clientMap:            make(map[string]*schemas.MCPClientState),
+		healthMonitorManager: NewHealthMonitorManager(),
 	}
-	manager.toolsHandler = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc)
+	manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc)
 	// Process client configs: create client map entries and establish connections
 	if len(config.ClientConfigs) > 0 {
 		for _, clientConfig := range config.ClientConfigs {
@@ -102,11 +104,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas
 // Returns:
 // - *schemas.BifrostRequest: The request with tools added
 func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest {
-	return m.toolsHandler.ParseAndAddToolsToRequest(ctx, req)
+	return m.toolsManager.ParseAndAddToolsToRequest(ctx, req)
 }
 
 func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool {
-	return m.toolsHandler.GetAvailableTools(ctx)
+	return m.toolsManager.GetAvailableTools(ctx)
 }
 
 // ExecuteChatTool executes a single tool call and returns the result as a chat message.
@@ -129,7 +131,7 @@ func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool {
 // - *schemas.ChatMessage: The result message containing tool execution output
 // - error: Any error that occurred during tool execution
 func (m *MCPManager) ExecuteChatTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
-	return m.toolsHandler.ExecuteChatTool(ctx, toolCall)
+	return m.toolsManager.ExecuteChatTool(ctx, toolCall)
 }
 
 // ExecuteResponsesTool executes a single tool call and returns the result as a responses message.
@@ -141,7 +143,7 @@ func (m *MCPManager) ExecuteChatTool(ctx context.Context, toolCall schemas.ChatA
 // - *schemas.ResponsesMessage: The result message containing tool execution output
 // - error: Any error that occurred during tool execution
 func (m *MCPManager) ExecuteResponsesTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, error) {
-	return m.toolsHandler.ExecuteResponsesTool(ctx, toolCall)
+	return m.toolsManager.ExecuteResponsesTool(ctx, toolCall)
 }
 
 // UpdateToolManagerConfig updates the configuration for the tool manager.
@@ -150,7 +152,7 @@ func (m *MCPManager) ExecuteResponsesTool(ctx context.Context, toolCall *schemas
 // Parameters:
 // - config: The new tool manager configuration to apply
 func (m *MCPManager) UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) {
-	m.toolsHandler.UpdateConfig(config)
+	m.toolsManager.UpdateConfig(config)
 }
 
 // CheckAndExecuteAgentForChatRequest checks if the chat response contains tool calls,
@@ -195,7 +197,7 @@ func (m *MCPManager) CheckAndExecuteAgentForChatRequest(
 		return response, nil
 	}
 	// Execute agent mode
-	return m.toolsHandler.ExecuteAgentForChatRequest(ctx, req, response, makeReq)
+	return m.toolsManager.ExecuteAgentForChatRequest(ctx, req, response, makeReq)
 }
 
 // CheckAndExecuteAgentForResponsesRequest checks if the responses response contains tool calls,
@@ -246,7 +248,7 @@ func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest(
 		return response, nil
 	}
 	// Execute agent mode
-	return m.toolsHandler.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq)
+	return m.toolsManager.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq)
 }
 
 // Cleanup performs cleanup of all MCP resources including clients and local server.
@@ -257,6 +259,9 @@ func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest(
 // Returns:
 // - error: Always returns nil, but maintains error interface for consistency
 func (m *MCPManager) Cleanup() error {
+	// Stop all health monitors first
+	m.healthMonitorManager.StopAll()
+
 	m.mu.Lock()
 	defer m.mu.Unlock()
diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go
index 663eaa1f4..f54998cc5 100644
--- a/core/schemas/mcp.go
+++ b/core/schemas/mcp.go
@@ -100,6 +100,7 @@ type MCPClientState struct {
 	ToolMap        map[string]ChatTool     // Available tools mapped by name
 	ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management
 	CancelFunc     context.CancelFunc      `json:"-"`               // Cancel function for SSE connections (not serialized)
+	State          MCPConnectionState      // Connection state (connected, disconnected, error)
 }
 
 // MCPClientConnectionInfo stores metadata about how a client is connected.
diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go
index 6799eb299..56a1c94a2 100644
--- a/transports/bifrost-http/handlers/mcp.go
+++ b/transports/bifrost-http/handlers/mcp.go
@@ -168,7 +168,7 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) {
 			clients = append(clients, schemas.MCPClient{
 				Config: h.store.RedactMCPClientConfig(connectedClient.Config),
 				Tools:  sortedTools,
-				State:  connectedClient.State,
+				State:  connectedClient.State, // Use the state from MCPClientState
 			})
 		} else {
 			// Client is in config but not connected, mark as errored