Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1800,15 +1800,10 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) {
return tools[i].Name < tools[j].Name
})

state := schemas.MCPConnectionStateConnected
if client.Conn == nil {
state = schemas.MCPConnectionStateDisconnected
}

clientsInConfig = append(clientsInConfig, schemas.MCPClient{
Config: client.ExecutionConfig,
Tools: tools,
State: state,
State: client.State,
})
}

Expand Down Expand Up @@ -1937,6 +1932,7 @@ func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecution
return nil
}


// PROVIDER MANAGEMENT

// createBaseProvider creates a provider based on the base provider type
Expand Down
27 changes: 24 additions & 3 deletions core/mcp/clientmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ func (m *MCPManager) removeClientUnsafe(id string) error {
return fmt.Errorf("client %s not found", id)
}

logger.Info(fmt.Sprintf("%s Disconnecting MCP client: %s", MCPLogPrefix, client.ExecutionConfig.Name))
logger.Info(fmt.Sprintf("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name))

// Stop health monitoring for this client
m.healthMonitorManager.StopMonitoring(id)

// Cancel SSE context if present (required for proper SSE cleanup)
if client.CancelFunc != nil {
Expand All @@ -149,7 +152,7 @@ func (m *MCPManager) removeClientUnsafe(id string) error {
// This handles cleanup for all transport types (HTTP, STDIO, SSE)
if client.Conn != nil {
if err := client.Conn.Close(); err != nil {
logger.Error("%s Failed to close MCP client %s: %v", MCPLogPrefix, client.ExecutionConfig.Name, err)
logger.Error("%s Failed to close MCP server '%s': %v", MCPLogPrefix, client.ExecutionConfig.Name, err)
}
client.Conn = nil
}
Expand Down Expand Up @@ -400,6 +403,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
// Store the external client connection and details
client.Conn = externalClient
client.ConnectionInfo = connectionInfo
client.State = schemas.MCPConnectionStateConnected

// Store cancel function for SSE connections to enable proper cleanup
if config.ConnectionType == schemas.MCPConnectionTypeSSE {
Expand All @@ -411,7 +415,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
client.ToolMap[toolName] = tool
}

logger.Info(fmt.Sprintf("%s Connected to MCP client: %s", MCPLogPrefix, config.Name))
logger.Info(fmt.Sprintf("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name))
} else {
// Clean up resources before returning error: client was removed during connection setup
// Cancel SSE context if it was created
Expand All @@ -427,6 +431,23 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
return fmt.Errorf("client %s was removed during connection setup", config.Name)
}

// Register OnConnectionLost hook for SSE connections to detect idle timeouts
if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil {
externalClient.OnConnectionLost(func(err error) {
logger.Warn(fmt.Sprintf("%s SSE connection lost for MCP server '%s': %v", MCPLogPrefix, config.Name, err))
// Update state to disconnected
m.mu.Lock()
if client, exists := m.clientMap[config.ID]; exists {
client.State = schemas.MCPConnectionStateDisconnected
}
m.mu.Unlock()
})
}

// Start health monitoring for the client
monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval)
m.healthMonitorManager.StartMonitoring(monitor)

return nil
}

Expand Down
231 changes: 231 additions & 0 deletions core/mcp/health_monitor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
package mcp

import (
"context"
"fmt"
"sync"
"time"

"github.com/maximhq/bifrost/core/schemas"
)

const (
	// Health check configuration defaults used by ClientHealthMonitor.
	DefaultHealthCheckInterval = 10 * time.Second // Interval between successive health checks
	DefaultHealthCheckTimeout  = 5 * time.Second  // Per-check timeout applied to each Ping call
	MaxConsecutiveFailures     = 5                // Consecutive failed checks before the client is marked disconnected
)

// ClientHealthMonitor tracks the health status of an MCP client by
// periodically pinging its connection from a background goroutine and
// updating the client's connection state in the owning MCPManager.
type ClientHealthMonitor struct {
	manager                *MCPManager        // owning manager; used to look up the client and update its State
	clientID               string             // key of the monitored client in manager.clientMap
	interval               time.Duration      // time between health checks
	timeout                time.Duration      // timeout applied to each Ping
	maxConsecutiveFailures int                // failures tolerated before marking the client disconnected
	mu                     sync.Mutex         // guards the mutable fields below
	ticker                 *time.Ticker       // drives the periodic checks; created in Start
	ctx                    context.Context    // cancellation context for the monitor loop
	cancel                 context.CancelFunc // cancels ctx; set in Start, invoked in Stop
	isMonitoring           bool               // true while the monitor loop is running
	consecutiveFailures    int                // count of back-to-back failed health checks
}

// NewClientHealthMonitor creates a new health monitor for an MCP client.
//
// manager is the owning MCPManager, clientID identifies the client in the
// manager's clientMap, and interval is the time between health checks.
// A non-positive interval falls back to DefaultHealthCheckInterval.
// Call Start to begin monitoring.
func NewClientHealthMonitor(
	manager *MCPManager,
	clientID string,
	interval time.Duration,
) *ClientHealthMonitor {
	// Guard non-positive values, not just zero: a negative interval would
	// make time.NewTicker panic when Start is called.
	if interval <= 0 {
		interval = DefaultHealthCheckInterval
	}

	return &ClientHealthMonitor{
		manager:                manager,
		clientID:               clientID,
		interval:               interval,
		timeout:                DefaultHealthCheckTimeout,
		maxConsecutiveFailures: MaxConsecutiveFailures,
		isMonitoring:           false,
		consecutiveFailures:    0,
	}
}

// Start begins monitoring the client's health in a background goroutine.
// It is a no-op if the monitor is already running.
func (chm *ClientHealthMonitor) Start() {
	chm.mu.Lock()
	defer chm.mu.Unlock()

	if chm.isMonitoring {
		return // Already monitoring
	}

	chm.isMonitoring = true
	chm.ctx, chm.cancel = context.WithCancel(context.Background())
	chm.ticker = time.NewTicker(chm.interval)

	// Hand the loop its own ctx and ticker instead of letting it read the
	// struct fields: if the monitor is stopped and restarted, a lingering
	// old loop goroutine would otherwise race with Start replacing
	// chm.ctx/chm.ticker.
	go chm.monitorLoop(chm.ctx, chm.ticker)
	logger.Debug(fmt.Sprintf("%s Health monitor started for client %s (interval: %v)", MCPLogPrefix, chm.clientID, chm.interval))
}

// Stop stops monitoring the client's health. It cancels the loop's context
// and stops the ticker; it is a no-op if the monitor is not running.
// Note: Stop signals the loop to exit but does not wait for it.
func (chm *ClientHealthMonitor) Stop() {
	chm.mu.Lock()
	defer chm.mu.Unlock()

	if !chm.isMonitoring {
		return // Not monitoring
	}

	chm.isMonitoring = false
	if chm.ticker != nil {
		chm.ticker.Stop()
	}
	if chm.cancel != nil {
		chm.cancel()
	}
	logger.Debug(fmt.Sprintf("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID))
}

// monitorLoop runs a health check on every ticker tick until ctx is
// canceled. ctx and ticker are passed as arguments (not read from chm) so
// the loop never races with a concurrent Start replacing the struct fields.
func (chm *ClientHealthMonitor) monitorLoop(ctx context.Context, ticker *time.Ticker) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			chm.performHealthCheck()
		}
	}
}

// performHealthCheck performs a single health check on the client: it pings
// the connection with a timeout, tracks consecutive failures, and updates
// the client's connection state accordingly. If the client has been removed
// from the manager, the monitor stops itself.
func (chm *ClientHealthMonitor) performHealthCheck() {
	// Look up the client and capture its connection while holding the read
	// lock. Reading clientState.Conn after unlocking would race with
	// removeClientUnsafe setting Conn to nil under the write lock.
	chm.manager.mu.RLock()
	clientState, exists := chm.manager.clientMap[chm.clientID]
	var conn = (*schemas.MCPClientState)(nil) // placeholder; reassigned below when the client exists
	_ = conn
	chm.manager.mu.RUnlock()

	if !exists {
		// Client was removed; no point monitoring any longer.
		chm.Stop()
		return
	}

	chm.manager.mu.RLock()
	connSnapshot := clientState.Conn
	chm.manager.mu.RUnlock()

	if connSnapshot == nil {
		// Client not connected, mark as disconnected
		chm.updateClientState(schemas.MCPConnectionStateDisconnected)
		chm.incrementFailures()
		return
	}

	// Perform ping with timeout
	ctx, cancel := context.WithTimeout(context.Background(), chm.timeout)
	defer cancel()

	if err := connSnapshot.Ping(ctx); err != nil {
		chm.incrementFailures()

		// Only flip to disconnected after sustained failures so a single
		// transient timeout does not flap the state.
		if chm.getConsecutiveFailures() >= chm.maxConsecutiveFailures {
			chm.updateClientState(schemas.MCPConnectionStateDisconnected)
		}
		return
	}

	// Health check passed
	chm.resetFailures()
	chm.updateClientState(schemas.MCPConnectionStateConnected)
}

// updateClientState records the client's connection state in the manager,
// logging a state-transition message only when the value actually changes.
func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionState) {
	chm.manager.mu.Lock()
	clientState, ok := chm.manager.clientMap[chm.clientID]
	if !ok {
		// Client vanished; nothing to update.
		chm.manager.mu.Unlock()
		return
	}
	changed := clientState.State != state
	if changed {
		clientState.State = state
	}
	chm.manager.mu.Unlock()

	// Emit the log line outside the critical section to keep it short.
	if changed {
		logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, chm.clientID, state))
	}
}

// incrementFailures bumps the consecutive-failure counter under the lock.
func (chm *ClientHealthMonitor) incrementFailures() {
	chm.mu.Lock()
	chm.consecutiveFailures++
	chm.mu.Unlock()
}

// resetFailures zeroes the consecutive-failure counter under the lock.
func (chm *ClientHealthMonitor) resetFailures() {
	chm.mu.Lock()
	chm.consecutiveFailures = 0
	chm.mu.Unlock()
}

// getConsecutiveFailures reads the consecutive-failure counter under the lock.
func (chm *ClientHealthMonitor) getConsecutiveFailures() int {
	chm.mu.Lock()
	n := chm.consecutiveFailures
	chm.mu.Unlock()
	return n
}

// HealthMonitorManager manages all client health monitors, keyed by
// client ID. All methods are safe for concurrent use.
type HealthMonitorManager struct {
	monitors map[string]*ClientHealthMonitor // active monitors by client ID
	mu       sync.RWMutex                    // guards monitors
}

// NewHealthMonitorManager creates an empty health monitor manager ready
// to track client monitors.
func NewHealthMonitorManager() *HealthMonitorManager {
	return &HealthMonitorManager{monitors: map[string]*ClientHealthMonitor{}}
}

// StartMonitoring registers and starts the given monitor, first stopping
// and replacing any monitor already registered for the same client.
func (hmm *HealthMonitorManager) StartMonitoring(monitor *ClientHealthMonitor) {
	hmm.mu.Lock()
	defer hmm.mu.Unlock()

	id := monitor.clientID
	if prev, found := hmm.monitors[id]; found {
		// Only one monitor per client: retire the old one first.
		prev.Stop()
	}

	hmm.monitors[id] = monitor
	monitor.Start()
}

// StopMonitoring stops and removes the monitor for the given client ID.
// It is a no-op if no monitor is registered for that client.
func (hmm *HealthMonitorManager) StopMonitoring(clientID string) {
	hmm.mu.Lock()
	defer hmm.mu.Unlock()

	monitor, found := hmm.monitors[clientID]
	if !found {
		return
	}
	monitor.Stop()
	delete(hmm.monitors, clientID)
}

// StopAll stops every registered monitor and empties the registry.
func (hmm *HealthMonitorManager) StopAll() {
	hmm.mu.Lock()
	defer hmm.mu.Unlock()

	// Deleting while ranging is safe in Go; this leaves the registry empty.
	for id, monitor := range hmm.monitors {
		monitor.Stop()
		delete(hmm.monitors, id)
	}
}
Loading