diff --git a/bifrost.go b/bifrost.go index 995664122..3e95a2c6b 100644 --- a/bifrost.go +++ b/bifrost.go @@ -19,26 +19,19 @@ const ( ChatCompletionRequest RequestType = "chat_completion" ) -// Request represents a generic request for text or chat completion -type Request struct { - Model string - Input RequestInput - Params *interfaces.ModelParameters +type ChannelMessage struct { + interfaces.BifrostRequest Response chan *interfaces.CompletionResult Err chan error Type RequestType } -type RequestInput struct { - StringInput *string - MessageInput *[]interfaces.Message -} - // Bifrost manages providers and maintains infinite open channels type Bifrost struct { account interfaces.Account - providers []interfaces.Provider // list of processed providers - requestQueues map[interfaces.SupportedModelProvider]chan Request // provider request queues + providers []interfaces.Provider // list of processed providers + plugins []interfaces.Plugin + requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues wg map[interfaces.SupportedModelProvider]*sync.WaitGroup } @@ -70,7 +63,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro return fmt.Errorf("failed to get keys for provider: %v", err) } - queue := make(chan Request, concurrencyAndBuffer.BufferSize) // Buffered channel per provider + queue := make(chan ChannelMessage, concurrencyAndBuffer.BufferSize) // Buffered channel per provider bifrost.requestQueues[provider.GetProviderKey()] = queue @@ -86,8 +79,8 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro } // Initializes infinite listening channels for each provider -func Init(account interfaces.Account) (*Bifrost, error) { - bifrost := &Bifrost{account: account} +func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, error) { + bifrost := &Bifrost{account: account, plugins: plugins} bifrost.wg = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup) providerKeys, err := bifrost.account.GetInitiallyConfiguredProviderKeys() @@ -95,7 +88,7 @@ func Init(account interfaces.Account) (*Bifrost, error) { return nil, err } - bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan Request) + bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage) // Create buffered channels for each provider and start workers for _, providerKey := range providerKeys { @@ -162,7 +155,7 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(provider interfaces.Provid return supportedKeys[len(supportedKeys)-1].Value, nil } -func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan Request) { +func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) { defer bifrost.wg[provider.GetProviderKey()].Done() for req := range queue { @@ -201,8 +194,8 @@ func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.Supp return nil, fmt.Errorf("no provider found for key: %s", key) } -func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan Request, error) { - var queue chan Request +func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan ChannelMessage, error) { + var queue chan ChannelMessage var exists bool if queue, exists = bifrost.requestQueues[providerKey]; !exists { @@ -216,7 +209,7 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr return queue, nil } -func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, model, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest) (*interfaces.CompletionResult, error) { queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err @@ -225,24 +218,44 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo responseChan := make(chan *interfaces.CompletionResult) errorChan := make(chan error) - queue <- Request{ - Model: model, - Input: RequestInput{StringInput: &text}, - Params: params, - Response: responseChan, - Err: errorChan, - Type: TextCompletionRequest, + for _, plugin := range bifrost.plugins { + if req.PluginParams == nil { + req.PluginParams = make(map[string]interface{}) + } + + req, err = plugin.PreHook(req) + + if err != nil { + return nil, err + } + } + + queue <- ChannelMessage{ + BifrostRequest: *req, + Response: responseChan, + Err: errorChan, + Type: TextCompletionRequest, } select { case result := <-responseChan: + result.PluginParams = req.PluginParams + + for _, plugin := range bifrost.plugins { + result, err = plugin.PostHook(result) + + if err != nil { + return nil, err + } + } + return result, nil case err := <-errorChan: return nil, err } } -func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, model string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest) (*interfaces.CompletionResult, error) { queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err @@ -251,18 +264,38 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo responseChan := make(chan *interfaces.CompletionResult) errorChan := make(chan error) - queue <- Request{ - Model: model, - Input: RequestInput{MessageInput: &messages}, - Params: params, - Response: responseChan, - Err: errorChan, - Type: ChatCompletionRequest, + for _, plugin := range bifrost.plugins { + if req.PluginParams == nil { + req.PluginParams = make(map[string]interface{}) + } + + req, err = plugin.PreHook(req) + + if err != nil { + return nil, err + } + } + + queue <- ChannelMessage{ + BifrostRequest: *req, + Response: responseChan, + Err: errorChan, + Type: ChatCompletionRequest, } // Wait for response select { case result := <-responseChan: + result.PluginParams = req.PluginParams + + for _, plugin := range bifrost.plugins { + result, err = plugin.PostHook(result) + + if err != nil { + return nil, err + } + } + return result, nil case err := <-errorChan: return nil, err diff --git a/go.mod b/go.mod index 55ecf3c5a..0145bb251 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module bifrost go 1.21.1 require github.com/joho/godotenv v1.5.1 + +require github.com/maximhq/maxim-go v0.1.1 // indirect diff --git a/go.sum b/go.sum index d61b19e1a..ebeda6c0c 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M= +github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= diff --git a/interfaces/plugin.go b/interfaces/plugin.go new file mode 100644 index 000000000..b983ca17c --- /dev/null +++ b/interfaces/plugin.go @@ -0,0 +1,18 @@ +package interfaces + +type RequestInput struct { + StringInput *string + MessageInput *[]Message +} + +type BifrostRequest struct { + Model string + Input RequestInput + Params *ModelParameters + PluginParams map[string]interface{} +} + +type Plugin interface { + PreHook(req *BifrostRequest) (*BifrostRequest, error) + PostHook(result *CompletionResult) (*CompletionResult, error) +} diff --git a/interfaces/provider.go b/interfaces/provider.go index 71ca26cae..495e706ae 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -124,7 +124,8 @@ type CompletionResult struct { Cost *LLMInteractionCost `json:"cost"` Model string `json:"model"` Created string `json:"created"` - ModelParams *interface{} `json:"modelParams"` + Params *interface{} `json:"modelParams"` + PluginParams map[string]interface{} `json:"-"` Trace *struct { Input interface{} `json:"input"` Output interface{} `json:"output"` diff --git a/tests/account.go b/tests/account.go index 2533bc710..3af3543f1 100644 --- a/tests/account.go +++ b/tests/account.go @@ -10,12 +10,12 @@ import ( type BaseAccount struct{} // GetInitiallyConfiguredProviderKeys returns all provider keys -func (ba *BaseAccount) GetInitiallyConfiguredProviderKeys() ([]interfaces.SupportedModelProvider, error) { +func (baseAccount *BaseAccount) GetInitiallyConfiguredProviderKeys() ([]interfaces.SupportedModelProvider, error) { return []interfaces.SupportedModelProvider{interfaces.OpenAI, interfaces.Anthropic}, nil } // GetKeysForProvider returns all keys associated with a provider -func (ba *BaseAccount) GetKeysForProvider(provider interfaces.Provider) ([]interfaces.Key, error) { +func (baseAccount *BaseAccount) GetKeysForProvider(provider interfaces.Provider) ([]interfaces.Key, error) { switch provider.GetProviderKey() { case interfaces.OpenAI: return []interfaces.Key{ @@ -39,7 +39,7 @@ func (ba *BaseAccount) GetKeysForProvider(provider interfaces.Provider) ([]inter } // GetConcurrencyAndBufferSizeForProvider returns the concurrency and buffer size settings for a provider -func (ba *BaseAccount) GetConcurrencyAndBufferSizeForProvider(provider interfaces.Provider) (*interfaces.ConcurrencyAndBufferSize, error) { +func (baseAccount *BaseAccount) GetConcurrencyAndBufferSizeForProvider(provider interfaces.Provider) (*interfaces.ConcurrencyAndBufferSize, error) { switch provider.GetProviderKey() { case interfaces.OpenAI: return &interfaces.ConcurrencyAndBufferSize{ diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go index 952a0c16e..d830ed285 100644 --- a/tests/anthropic_test.go +++ b/tests/anthropic_test.go @@ -18,13 +18,20 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { } go func() { - config := interfaces.ModelParameters{ + params := interfaces.ModelParameters{ ExtraParams: map[string]interface{}{ "max_tokens_to_sample": 4096, }, } + text := "Hello world!" - result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, "claude-2.1", "Hello world!", &config) + result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ + Model: "claude-2.1", + Input: interfaces.RequestInput{ + StringInput: &text, + }, + Params: ¶ms, + }) if err != nil { fmt.Println("Error:", err) } else { @@ -32,7 +39,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { } }() - config := interfaces.ModelParameters{ + params := interfaces.ModelParameters{ ExtraParams: map[string]interface{}{ "max_tokens": 4096, }, @@ -48,7 +55,14 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { Content: &msg, }, } - result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, "claude-3-7-sonnet-20250219", messages, &config) + result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ + Model: "claude-3-7-sonnet-20250219", + Input: interfaces.RequestInput{ + MessageInput: &messages, + }, + Params: ¶ms, + }) + if err != nil { fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err) } else { diff --git a/tests/openai_test.go b/tests/openai_test.go index 1e82a31d5..5c8615de1 100644 --- a/tests/openai_test.go +++ b/tests/openai_test.go @@ -10,9 +10,17 @@ import ( // setupOpenAIRequests sends multiple test requests to OpenAI func setupOpenAIRequests(bifrost *bifrost.Bifrost) { + text := "Hello world!" + // Text completion request go func() { - result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, "gpt-4o-mini", "Hello world!", nil) + result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + StringInput: &text, + }, + Params: nil, + }) if err != nil { fmt.Println("Error:", err) } else { @@ -38,7 +46,13 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { Content: &msg, }, } - result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, "gpt-4o-mini", messages, nil) + result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + MessageInput: &messages, + }, + Params: nil, + }) if err != nil { fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err) } else { diff --git a/tests/plugin.go b/tests/plugin.go new file mode 100644 index 000000000..8ceaad20e --- /dev/null +++ b/tests/plugin.go @@ -0,0 +1,38 @@ +package tests + +import ( + "bifrost/interfaces" + "fmt" + "time" + + "github.com/maximhq/maxim-go" + "github.com/maximhq/maxim-go/logging" +) + +type Plugin struct { + logger *logging.Logger +} + +func (plugin *Plugin) PreHook(req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) { + traceID := time.Now().Format("20060102_150405000") + + trace := plugin.logger.Trace(&logging.TraceConfig{ + Id: traceID, + Name: maxim.StrPtr("bifrost"), + }) + + trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req)) + + req.PluginParams["traceID"] = traceID + + return req, nil +} + +func (plugin *Plugin) PostHook(res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) { + fmt.Println(res.PluginParams) + + traceID := res.PluginParams["traceID"].(string) + + plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res)) + return res, nil +} diff --git a/tests/setup.go b/tests/setup.go index 00b2eec99..d088ad4f4 100644 --- a/tests/setup.go +++ b/tests/setup.go @@ -2,9 +2,14 @@ package tests import ( "bifrost" + "bifrost/interfaces" + "fmt" "log" + "os" "github.com/joho/godotenv" + "github.com/maximhq/maxim-go" + "github.com/maximhq/maxim-go/logging" ) func loadEnv() { @@ -14,12 +19,32 @@ func loadEnv() { } } +func getPlugin() (interfaces.Plugin, error) { + loadEnv() + + mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: os.Getenv("MAXIM_API_KEY")}) + + logger, err := mx.GetLogger(&logging.LoggerConfig{Id: os.Getenv("MAXIM_LOGGER_ID")}) + if err != nil { + return nil, err + } + + plugin := &Plugin{logger} + + return plugin, nil +} + func getBifrost() (*bifrost.Bifrost, error) { loadEnv() account := BaseAccount{} + plugin, err := getPlugin() + if err != nil { + fmt.Println("Error setting up the plugin:", err) + return nil, err + } - bifrost, err := bifrost.Init(&account) + bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}) if err != nil { return nil, err }