diff --git a/bifrost.go b/bifrost.go
index 36db4579b..a331a69de 100644
--- a/bifrost.go
+++ b/bifrost.go
@@ -6,6 +6,7 @@ import (
 	"math/rand"
 	"os"
 	"os/signal"
+	"slices"
 	"sync"
 	"syscall"
 	"time"
@@ -35,14 +36,15 @@ type Bifrost struct {
 	plugins       []interfaces.Plugin
 	requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues
 	waitGroups    map[interfaces.SupportedModelProvider]*sync.WaitGroup
+	logger        interfaces.Logger
 }
 
-func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) {
+func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) {
 	switch providerKey {
 	case interfaces.OpenAI:
-		return providers.NewOpenAIProvider(config), nil
+		return providers.NewOpenAIProvider(config, bifrost.logger), nil
 	case interfaces.Anthropic:
-		return providers.NewAnthropicProvider(config), nil
+		return providers.NewAnthropicProvider(config, bifrost.logger), nil
 	case interfaces.Bedrock:
 		return providers.NewBedrockProvider(config), nil
 	case interfaces.Cohere:
@@ -71,12 +73,12 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro
 	// Start specified number of workers
 	bifrost.waitGroups[providerKey] = &sync.WaitGroup{}
 
-	provider, err := createProviderFromProviderKey(providerKey, config)
+	provider, err := bifrost.createProviderFromProviderKey(providerKey, config)
 	if err != nil {
 		return fmt.Errorf("failed to get provider for the given key: %v", err)
 	}
 
-	for i := 0; i < providerConfig.ConcurrencyAndBufferSize.Concurrency; i++ {
+	for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
 		bifrost.waitGroups[providerKey].Add(1)
 		go bifrost.processRequests(provider, queue)
 	}
@@ -85,7 +87,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro
 }
 
 // Initializes infinite listening channels for each provider
-func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, error) {
+func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interfaces.Logger) (*Bifrost, error) {
 	bifrost := &Bifrost{account: account, plugins: plugins}
 	bifrost.waitGroups = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup)
 
@@ -94,6 +96,11 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, er
 		return nil, err
 	}
 
+	if logger == nil {
+		logger = NewDefaultLogger(interfaces.LogLevelInfo)
+	}
+	bifrost.logger = logger
+
 	bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage)
 
 	// Create buffered channels for each provider and start workers
@@ -104,7 +111,7 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, er
 		}
 
 		if err := bifrost.prepareProvider(providerKey, config); err != nil {
-			fmt.Printf("failed to prepare provider: %v", err)
+			bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider: %v", err))
 		}
 	}
 
@@ -124,11 +131,8 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.Sup
 	// filter out keys which dont support the model
 	var supportedKeys []interfaces.Key
 	for _, key := range keys {
-		for _, supportedModel := range key.Models {
-			if supportedModel == model {
-				supportedKeys = append(supportedKeys, key)
-				break
-			}
+		if slices.Contains(key.Models, model) {
+			supportedKeys = append(supportedKeys, key)
 		}
 	}
 
@@ -200,7 +204,7 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
} } - fmt.Println("Worker for provider", provider.GetProviderKey(), "exiting...") + bifrost.logger.Debug(fmt.Sprintf("Worker for provider %s exiting...", provider.GetProviderKey())) } func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) { @@ -331,7 +335,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo // Shutdown gracefully stops all workers when triggered func (bifrost *Bifrost) Shutdown() { - fmt.Println("\n[BIFROST] Graceful Shutdown Initiated - Closing all request channels...") + bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...") // Close all provider queues to signal workers to stop for _, queue := range bifrost.requestQueues { diff --git a/interfaces/logger.go b/interfaces/logger.go new file mode 100644 index 000000000..2d97af4c1 --- /dev/null +++ b/interfaces/logger.go @@ -0,0 +1,26 @@ +package interfaces + +// LogLevel represents the severity level of a log message +type LogLevel string + +const ( + LogLevelDebug LogLevel = "debug" + LogLevelInfo LogLevel = "info" + LogLevelWarn LogLevel = "warn" + LogLevelError LogLevel = "error" +) + +// Logger defines the interface for logging operations +type Logger interface { + // Debug logs a debug level message + Debug(msg string) + + // Info logs an info level message + Info(msg string) + + // Warn logs a warning level message + Warn(msg string) + + // Error logs an error level message + Error(err error) +} diff --git a/logger.go b/logger.go new file mode 100644 index 000000000..f9002190c --- /dev/null +++ b/logger.go @@ -0,0 +1,62 @@ +package bifrost + +import ( + "fmt" + "os" + "time" + + "github.com/maximhq/bifrost/interfaces" +) + +// DefaultLogger implements the Logger interface with stdout printing +type DefaultLogger struct { + level interfaces.LogLevel +} + +// NewDefaultLogger creates a new DefaultLogger instance +func NewDefaultLogger(level interfaces.LogLevel) *DefaultLogger { + return &DefaultLogger{ + level: level, + } +} + +// formatMessage formats the log message with timestamp and level +func (logger *DefaultLogger) formatMessage(level interfaces.LogLevel, msg string, err error) string { + timestamp := time.Now().Format(time.RFC3339) + baseMsg := fmt.Sprintf("[BIFROST-%s] %s: %s", timestamp, level, msg) + if err != nil { + return fmt.Sprintf("%s (error: %v)", baseMsg, err) + } + return baseMsg +} + +// Debug logs a debug level message +func (logger *DefaultLogger) Debug(msg string) { + if logger.level == interfaces.LogLevelDebug { + fmt.Fprintln(os.Stdout, logger.formatMessage(interfaces.LogLevelDebug, msg, nil)) + } +} + +// Info logs an info level message +func (logger *DefaultLogger) Info(msg string) { + if logger.level == interfaces.LogLevelDebug || logger.level == interfaces.LogLevelInfo { + fmt.Fprintln(os.Stdout, logger.formatMessage(interfaces.LogLevelInfo, msg, nil)) + } +} + +// Warn logs a warning level message +func (logger *DefaultLogger) Warn(msg string) { + if logger.level == interfaces.LogLevelDebug || logger.level == interfaces.LogLevelInfo || logger.level == interfaces.LogLevelWarn { + fmt.Fprintln(os.Stdout, logger.formatMessage(interfaces.LogLevelWarn, msg, nil)) + } +} + +// Error logs an error level message +func (logger *DefaultLogger) Error(err error) { + fmt.Fprintln(os.Stderr, logger.formatMessage(interfaces.LogLevelError, "", err)) +} + +// SetLevel sets the logging level +func (logger *DefaultLogger) SetLevel(level interfaces.LogLevel) { + 
+	logger.level = level
+}
diff --git a/providers/anthropic.go b/providers/anthropic.go
index 272750d98..74b40e382 100644
--- a/providers/anthropic.go
+++ b/providers/anthropic.go
@@ -51,12 +51,14 @@ type AnthropicChatResponse struct {
 
 // AnthropicProvider implements the Provider interface for Anthropic's Claude API
 type AnthropicProvider struct {
+	logger interfaces.Logger
 	client *fasthttp.Client
 }
 
 // NewAnthropicProvider creates a new AnthropicProvider instance
-func NewAnthropicProvider(config *interfaces.ProviderConfig) *AnthropicProvider {
+func NewAnthropicProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *AnthropicProvider {
 	return &AnthropicProvider{
+		logger: logger,
 		client: &fasthttp.Client{
 			ReadTimeout:  time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
 			WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
diff --git a/providers/openai.go b/providers/openai.go
index 949990840..c3b5cf2ed 100644
--- a/providers/openai.go
+++ b/providers/openai.go
@@ -22,12 +22,14 @@ type OpenAIResponse struct {
 
 // OpenAIProvider implements the Provider interface for OpenAI
 type OpenAIProvider struct {
+	logger interfaces.Logger
 	client *fasthttp.Client
 }
 
 // NewOpenAIProvider creates a new OpenAI provider instance
-func NewOpenAIProvider(config *interfaces.ProviderConfig) *OpenAIProvider {
+func NewOpenAIProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *OpenAIProvider {
 	return &OpenAIProvider{
+		logger: logger,
 		client: &fasthttp.Client{
 			ReadTimeout:  time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
 			WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
diff --git a/tests/setup.go b/tests/setup.go
index 812a6ccf3..f7eb058fc 100644
--- a/tests/setup.go
+++ b/tests/setup.go
@@ -45,7 +45,7 @@ func getBifrost() (*bifrost.Bifrost, error) {
 		return nil, err
 	}
 
-	bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin})
+	bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}, nil)
 	if err != nil {
 		return nil, err
 	}
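
For reference, a minimal sketch of how a caller might adopt the new Init signature: passing nil falls back to the built-in DefaultLogger at info level (as the diff above shows), while any type satisfying interfaces.Logger can be injected instead. The slogAdapter type, the newClient helper, and the root import path are assumptions for illustration, not part of this change; account and plugins are assumed to be prepared the same way as in tests/setup.go.

package example

import (
	"log/slog"

	bifrost "github.com/maximhq/bifrost" // assumed root import path, matching the interfaces subpackage
	"github.com/maximhq/bifrost/interfaces"
)

// slogAdapter is a hypothetical adapter showing how an existing logging
// backend (here the standard library's log/slog) can satisfy interfaces.Logger.
type slogAdapter struct{ l *slog.Logger }

func (a *slogAdapter) Debug(msg string) { a.l.Debug(msg) }
func (a *slogAdapter) Info(msg string)  { a.l.Info(msg) }
func (a *slogAdapter) Warn(msg string)  { a.l.Warn(msg) }
func (a *slogAdapter) Error(err error)  { a.l.Error(err.Error()) }

// newClient wires up Bifrost with the updated Init signature.
func newClient(account interfaces.Account, plugins []interfaces.Plugin, custom bool) (*bifrost.Bifrost, error) {
	if !custom {
		// nil logger: Init falls back to NewDefaultLogger(interfaces.LogLevelInfo).
		return bifrost.Init(account, plugins, nil)
	}
	// Custom logger: any interfaces.Logger implementation can be supplied.
	return bifrost.Init(account, plugins, &slogAdapter{l: slog.Default()})
}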