diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..cd9b5cd25 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,30 @@ +{ + "version": "0.2.0", + "configurations": [ + + { + "name": "Launch Bifrost", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${workspaceFolder}/bifrost.go", + "args": [] + }, + { + "name": "Debug Tests", + "type": "go", + "request": "launch", + "mode": "test", + "program": "${workspaceFolder}/tests", + "args": [] + }, + { + "name": "Attach to Process", + "type": "go", + "request": "attach", + "mode": "local", + "processId": "${command:pickProcess}" + } + ] + } + \ No newline at end of file diff --git a/bifrost.go b/bifrost.go index a31e6f6b4..4e8e6368c 100644 --- a/bifrost.go +++ b/bifrost.go @@ -1,15 +1,17 @@ package bifrost import ( - "bifrost/interfaces" + "context" "fmt" - "log" "math/rand" "os" "os/signal" "sync" "syscall" "time" + + "github.com/maximhq/bifrost/interfaces" + "github.com/maximhq/bifrost/providers" ) type RequestType string @@ -19,13 +21,9 @@ const ( ChatCompletionRequest RequestType = "chat_completion" ) -// Request represents a generic request for text or chat completion -type Request struct { - Model string - //* is this okay or should we do string | Message? - Input interface{} - Params *interfaces.ModelParameters - Response chan *interfaces.CompletionResult +type ChannelMessage struct { + interfaces.BifrostRequest + Response chan *interfaces.BifrostResponse Err chan error Type RequestType } @@ -33,31 +31,53 @@ type Request struct { // Bifrost manages providers and maintains infinite open channels type Bifrost struct { account interfaces.Account - providers []interfaces.Provider // list of processed providers - requestQueues map[string]chan Request // provider request queues - wg sync.WaitGroup + providers []interfaces.Provider // list of processed providers + plugins []interfaces.Plugin + requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues + waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup +} + +func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) { + switch providerKey { + case interfaces.OpenAI: + return providers.NewOpenAIProvider(config), nil + case interfaces.Anthropic: + return providers.NewAnthropicProvider(config), nil + case interfaces.Bedrock: + return providers.NewBedrockProvider(config), nil + case interfaces.Cohere: + return providers.NewCohereProvider(config), nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } } -func (bifrost *Bifrost) prepareProvider(provider interfaces.Provider) error { - concurrency, err := bifrost.account.GetConcurrencyAndBufferSizeForProvider(provider) +func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) error { + providerConfig, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil { - log.Fatalf("Failed to get concurrency and buffer size for provider: %v", err) - return err + return fmt.Errorf("failed to get config for provider: %v", err) } // Check if the provider has any keys - keys, err := bifrost.account.GetKeysForProvider(provider) + keys, err := bifrost.account.GetKeysForProvider(providerKey) if err != nil || len(keys) == 0 { - log.Fatalf("Failed to get keys for provider: %v", err) - return err + return fmt.Errorf("failed to get keys for provider: %v", err) } - 
queue := make(chan Request, concurrency.BufferSize) // Buffered channel per provider - bifrost.requestQueues[string(provider.GetProviderKey())] = queue + queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider + + bifrost.requestQueues[providerKey] = queue // Start specified number of workers - for i := 0; i < concurrency.Concurrency; i++ { - bifrost.wg.Add(1) + bifrost.waitGroups[providerKey] = &sync.WaitGroup{} + + provider, err := createProviderFromProviderKey(providerKey, config) + if err != nil { + return fmt.Errorf("failed to get provider for the given key: %v", err) + } + + for i := 0; i < providerConfig.ConcurrencyAndBufferSize.Concurrency; i++ { + bifrost.waitGroups[providerKey].Add(1) go bifrost.processRequests(provider, queue) } @@ -65,36 +85,40 @@ func (bifrost *Bifrost) prepareProvider(provider interfaces.Provider) error { } // Initializes infinite listening channels for each provider -func Init(account interfaces.Account) (*Bifrost, error) { - bifrost := &Bifrost{account: account} +func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, error) { + bifrost := &Bifrost{account: account, plugins: plugins} + bifrost.waitGroups = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup) - providers, err := bifrost.account.GetInitiallyConfiguredProviders() + providerKeys, err := bifrost.account.GetInitiallyConfiguredProviders() if err != nil { - log.Fatalf("Failed to get initially configured providers: %v", err) return nil, err } - bifrost.requestQueues = make(map[string]chan Request) + bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage) // Create buffered channels for each provider and start workers - for _, provider := range providers { - if err := bifrost.prepareProvider(provider); err != nil { - log.Fatalf("Failed to prepare provider: %v", err) - return nil, err + for _, providerKey := range providerKeys { + config, err := bifrost.account.GetConfigForProvider(providerKey) + if err != nil { + return nil, fmt.Errorf("failed to get config for provider: %v", err) + } + + if err := bifrost.prepareProvider(providerKey, config); err != nil { + fmt.Printf("failed to prepare provider: %v\n", err) } } return bifrost, nil } -func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, model string) (string, error) { - keys, err := bifrost.account.GetKeysForProvider(provider) +func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.SupportedModelProvider, model string) (string, error) { + keys, err := bifrost.account.GetKeysForProvider(providerKey) if err != nil { return "", err } if len(keys) == 0 { - return "", fmt.Errorf("no keys found for provider: %v", provider.GetProviderKey()) + return "", fmt.Errorf("no keys found for provider: %v", providerKey) } // filter out keys which don't support the model @@ -113,10 +137,10 @@ } // Create a new random source - ran := rand.New(rand.NewSource(time.Now().UnixNano())) + randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) // Shuffle keys using the new random number generator - ran.Shuffle(len(supportedKeys), func(i, j int) { + randomSource.Shuffle(len(supportedKeys), func(i, j int) { supportedKeys[i], supportedKeys[j] = supportedKeys[j], supportedKeys[i] }) @@ -127,7 +151,7 @@ func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, mod } // Generate a random number within total 
weight - r := ran.Float64() * totalWeight + r := randomSource.Float64() * totalWeight var cumulative float64 // Select the key based on weighted probability @@ -142,23 +166,31 @@ func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, mod return supportedKeys[len(supportedKeys)-1].Value, nil } -func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan Request) { - defer bifrost.wg.Done() +func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) { + defer bifrost.waitGroups[provider.GetProviderKey()].Done() for req := range queue { - var result *interfaces.CompletionResult + var result *interfaces.BifrostResponse var err error - key, err := bifrost.SelectFromProviderKeys(provider, req.Model) + key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) if err != nil { req.Err <- err continue } if req.Type == TextCompletionRequest { - result, err = provider.TextCompletion(req.Model, key, req.Input.(string), req.Params) + if req.Input.TextInput == nil { + err = fmt.Errorf("text not provided for text completion request") + } else { + result, err = provider.TextCompletion(req.Model, key, *req.Input.TextInput, req.Params) + } } else if req.Type == ChatCompletionRequest { - result, err = provider.ChatCompletion(req.Model, key, req.Input.([]interfaces.Message), req.Params) + if req.Input.ChatInput == nil { + err = fmt.Errorf("chats not provided for chat completion request") + } else { + result, err = provider.ChatCompletion(req.Model, key, *req.Input.ChatInput, req.Params) + } } if err != nil { @@ -171,7 +203,7 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan fmt.Println("Worker for provider", provider.GetProviderKey(), "exiting...") } -func (bifrost *Bifrost) GetProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) { +func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) { for _, provider := range bifrost.providers { if provider.GetProviderKey() == key { return provider, nil @@ -181,73 +213,116 @@ func (bifrost *Bifrost) GetProviderFromProviderKey(key interfaces.SupportedModel return nil, fmt.Errorf("no provider found for key: %s", key) } -func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan Request, error) { - var queue chan Request +func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan ChannelMessage, error) { + var queue chan ChannelMessage var exists bool - if queue, exists = bifrost.requestQueues[string(providerKey)]; !exists { - provider, err := bifrost.GetProviderFromProviderKey(providerKey) + if queue, exists = bifrost.requestQueues[providerKey]; !exists { + config, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get config for provider: %v", err) } - if err := bifrost.prepareProvider(provider); err != nil { + if err := bifrost.prepareProvider(providerKey, config); err != nil { return nil, err } - queue = bifrost.requestQueues[string(providerKey)] + queue = bifrost.requestQueues[providerKey] } return queue, nil } -func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, model, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) TextCompletionRequest(providerKey 
interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) { + if req == nil { + return nil, fmt.Errorf("bifrost request cannot be nil") + } + queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err } - responseChan := make(chan *interfaces.CompletionResult) + responseChan := make(chan *interfaces.BifrostResponse) errorChan := make(chan error) - queue <- Request{ - Model: model, - Input: text, - Params: params, - Response: responseChan, - Err: errorChan, - Type: TextCompletionRequest, + for _, plugin := range bifrost.plugins { + req, err = plugin.PreHook(&ctx, req) + if err != nil { + return nil, err + } + } + + if req == nil { + return nil, fmt.Errorf("bifrost request cannot be nil after plugin pre-hooks") + } + + queue <- ChannelMessage{ + BifrostRequest: *req, + Response: responseChan, + Err: errorChan, + Type: TextCompletionRequest, } select { case result := <-responseChan: + // Run plugins in reverse order + for i := len(bifrost.plugins) - 1; i >= 0; i-- { + result, err = bifrost.plugins[i].PostHook(&ctx, result) + + if err != nil { + return nil, err + } + } + return result, nil case err := <-errorChan: return nil, err } } -func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, model string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) { + if req == nil { + return nil, fmt.Errorf("bifrost request cannot be nil") + } + queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err } - responseChan := make(chan *interfaces.CompletionResult) + responseChan := make(chan *interfaces.BifrostResponse) errorChan := make(chan error) - queue <- Request{ - Model: model, - Input: messages, - Params: params, - Response: responseChan, - Err: errorChan, - Type: ChatCompletionRequest, + for _, plugin := range bifrost.plugins { + req, err = plugin.PreHook(&ctx, req) + if err != nil { + return nil, err + } + } + + if req == nil { + return nil, fmt.Errorf("bifrost request cannot be nil after plugin pre-hooks") + } + + queue <- ChannelMessage{ + BifrostRequest: *req, + Response: responseChan, + Err: errorChan, + Type: ChatCompletionRequest, } - // Wait for response select { case result := <-responseChan: + // Run plugins in reverse order + for i := len(bifrost.plugins) - 1; i >= 0; i-- { + result, err = bifrost.plugins[i].PostHook(&ctx, result) + + if err != nil { + return nil, err + } + } + return result, nil case err := <-errorChan: return nil, err @@ -256,7 +331,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo // Shutdown gracefully stops all workers when triggered func (bifrost *Bifrost) Shutdown() { - fmt.Println("\n[Graceful Shutdown Initiated] Closing all request channels...") + fmt.Println("\n[BIFROST] Graceful Shutdown Initiated - Closing all request channels...") // Close all provider queues to signal workers to stop for _, queue := range bifrost.requestQueues { @@ -264,9 +339,9 @@ } // Wait for all workers to exit - bifrost.wg.Wait() - - fmt.Println("Bifrost has shut down gracefully.") + for _, waitGroup := range bifrost.waitGroups { + waitGroup.Wait() + } } // Cleanup handles SIGINT (Ctrl+C) to exit cleanly diff --git a/go.mod 
b/go.mod index 3f9ab680e..ef8d638aa 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,31 @@ -module bifrost +module github.com/maximhq/bifrost -go 1.21.1 +go 1.23.0 -require github.com/joho/godotenv v1.5.1 // indirect +toolchain go1.24.1 + +require github.com/joho/godotenv v1.5.1 + +require ( + github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2/config v1.29.11 + github.com/maximhq/maxim-go v0.1.1 + github.com/valyala/fasthttp v1.58.0 +) + +require ( + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.64 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 // indirect + github.com/aws/smithy-go v1.22.2 // indirect + github.com/klauspost/compress v1.17.11 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect +) diff --git a/go.sum b/go.sum index d61b19e1a..06db2183d 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,40 @@ +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.11 h1:/hkJIxaQzFQy0ebFjG5NHmAcLCrvNSuXeHnxLfeCz1Y= +github.com/aws/aws-sdk-go-v2/config v1.29.11/go.mod h1:OFPRZVQxC4mKqy2Go6Cse/m9NOStAo6YaMvAcTMUROg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.64 h1:NH4RAQJEXBDQDUudTqMNHdyyEVa5CvMn0tQicqv48jo= +github.com/aws/aws-sdk-go-v2/credentials v1.17.64/go.mod h1:tUoJfj79lzEcalHDbyNkpnZZTRg/2ayYOK/iYnRfPbo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod 
h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 h1:pdgODsAhGo4dvzC3JAG5Ce0PX8kWXrTZGx+jxADD+5E= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.2/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2 h1:wK8O+j2dOolmpNVY1EWIbLgxrGCHJKVPm08Hv/u80M8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 h1:PZV5W8yk4OtH1JAuhV2PXwwO9v5G5Aoj+eMCn4T+1Kc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.17/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M= +github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbRuE= +github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= diff --git a/interfaces/account.go b/interfaces/account.go index c12f1e099..27d88d4b4 100644 --- a/interfaces/account.go +++ b/interfaces/account.go @@ -1,18 +1,14 @@ package interfaces -type ConcurrencyAndBufferSize struct { - Concurrency int `json:"concurrency"` - BufferSize int `json:"buffer_size"` -} - type Key struct { Value string `json:"value"` Models []string `json:"models"` Weight float64 `json:"weight"` } +// TODO one get config method type Account interface { - GetInitiallyConfiguredProviders() ([]Provider, error) - GetKeysForProvider(provider Provider) ([]Key, error) - GetConcurrencyAndBufferSizeForProvider(provider Provider) (ConcurrencyAndBufferSize, error) + GetInitiallyConfiguredProviders() ([]SupportedModelProvider, error) + GetKeysForProvider(providerKey SupportedModelProvider) ([]Key, error) + GetConfigForProvider(providerKey SupportedModelProvider) (*ProviderConfig, error) } diff --git a/interfaces/bifrost.go b/interfaces/bifrost.go new file mode 100644 index 000000000..f1d63395e --- /dev/null +++ b/interfaces/bifrost.go @@ -0,0 +1,229 @@ +package interfaces + +// ModelChatMessageRole represents the role of a chat message +type ModelChatMessageRole string + +const ( + RoleAssistant ModelChatMessageRole = "assistant" + RoleUser ModelChatMessageRole = "user" + RoleSystem ModelChatMessageRole = "system" + RoleChatbot ModelChatMessageRole = "chatbot" + RoleTool ModelChatMessageRole = "tool" +) + +type SupportedModelProvider string + +const ( + OpenAI SupportedModelProvider = "openai" + Azure SupportedModelProvider = "azure" + HuggingFace SupportedModelProvider = "huggingface" + Anthropic 
SupportedModelProvider = "anthropic" + Google SupportedModelProvider = "google" + Groq SupportedModelProvider = "groq" + Bedrock SupportedModelProvider = "bedrock" + Maxim SupportedModelProvider = "maxim" + Cohere SupportedModelProvider = "cohere" + Ollama SupportedModelProvider = "ollama" + Lmstudio SupportedModelProvider = "lmstudio" +) + +//* Request Structs + +type RequestInput struct { + TextInput *string + ChatInput *[]Message +} + +type BifrostRequest struct { + Model string + Input RequestInput + Params *ModelParameters +} + +// ModelParameters represents the parameters for model requests +type ModelParameters struct { + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` + Tools *[]Tool `json:"tools,omitempty"` + + // Common model parameters + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + StopSequences *[]string `json:"stop_sequences,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls"` + + // Dynamic parameters + ExtraParams map[string]interface{} `json:"-"` +} + +type FunctionParameters struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]interface{} `json:"properties"` +} + +// Function represents a function definition for tool calls +type Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters FunctionParameters `json:"parameters"` +} + +// Tool represents a tool that can be used with the model +type Tool struct { + ID *string `json:"id,omitempty"` + Type string `json:"type"` + Function Function `json:"function"` +} + +// ToolChoiceType represents the combined set of tool choices across all providers +type ToolChoiceType string + +const ( + ToolChoiceNone ToolChoiceType = "none" + ToolChoiceAuto ToolChoiceType = "auto" + ToolChoiceAny ToolChoiceType = "any" + ToolChoiceTool ToolChoiceType = "tool" + ToolChoiceRequired ToolChoiceType = "required" +) + +type ToolChoiceFunction struct { + Name string `json:"name"` +} + +type ToolChoice struct { + Type ToolChoiceType `json:"type"` + Function ToolChoiceFunction `json:"function"` +} + +type Message struct { + //* strict check for roles + Role ModelChatMessageRole `json:"role"` + //* need to make sure either content or imagecontent is provided + Content *string `json:"content,omitempty"` + ImageContent *ImageContent `json:"image_content,omitempty"` + ToolCalls *[]Tool `json:"tool_calls,omitempty"` +} + +type ImageContent struct { + Type string `json:"type"` + URL string `json:"url"` + MediaType string `json:"media_type"` +} + +//* Response Structs + +// LLMUsage represents token usage information +type LLMUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + TokenDetails *TokenDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` +} + +type TokenDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` +} + +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens int 
`json:"rejected_prediction_tokens,omitempty"` +} + +type BilledLLMUsage struct { + PromptTokens *float64 `json:"prompt_tokens,omitempty"` + CompletionTokens *float64 `json:"completion_tokens,omitempty"` + SearchUnits *float64 `json:"search_units,omitempty"` + Classifications *float64 `json:"classifications,omitempty"` +} + +type LogProb struct { + Bytes []int `json:"bytes,omitempty"` + LogProb float64 `json:"logprob"` + Token string `json:"token"` +} + +type ContentLogProb struct { + Bytes []int `json:"bytes"` + LogProb float64 `json:"logprob"` + Token string `json:"token"` + TopLogProbs []LogProb `json:"top_logprobs"` +} + +type LogProbs struct { + Content []ContentLogProb `json:"content"` + Refusal []LogProb `json:"refusal"` +} + +type FunctionCall struct { + Name *string `json:"name"` + Arguments string `json:"arguments"` // stringified JSON as returned by OpenAI; may not always be valid JSON +} + +// ToolCall represents a tool call in a message +type ToolCall struct { + Type *string `json:"type,omitempty"` + ID *string `json:"id,omitempty"` + Function FunctionCall `json:"function"` +} + +type Citation struct { + StartIndex int `json:"start_index"` + EndIndex int `json:"end_index"` + Title string `json:"title"` + URL *string `json:"url,omitempty"` + Sources *interface{} `json:"sources,omitempty"` + Type *string `json:"type,omitempty"` +} + +type Annotation struct { + Type string `json:"type"` + Citation Citation `json:"url_citation"` +} + +// BifrostResponseChoiceMessage represents a choice in the completion response +type BifrostResponseChoiceMessage struct { + Role ModelChatMessageRole `json:"role"` + Content *string `json:"content,omitempty"` + Refusal *string `json:"refusal,omitempty"` + Annotations []Annotation `json:"annotations,omitempty"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` +} + +// BifrostResponseChoice represents a choice in the completion result +type BifrostResponseChoice struct { + Index int `json:"index"` + Message BifrostResponseChoiceMessage `json:"message"` + FinishReason *string `json:"finish_reason,omitempty"` + StopString *string `json:"stop,omitempty"` + LogProbs *LogProbs `json:"log_probs,omitempty"` +} + +type BifrostResponseExtraFields struct { + Provider SupportedModelProvider `json:"provider"` + Params ModelParameters `json:"model_params"` + Latency *float64 `json:"latency,omitempty"` + ChatHistory *[]BifrostResponseChoiceMessage `json:"chat_history,omitempty"` + BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` + RawResponse interface{} `json:"raw_response"` +} + +// BifrostResponse represents the complete result from a model completion +type BifrostResponse struct { + ID string `json:"id"` + Object string `json:"object"` // text.completion or chat.completion + Choices []BifrostResponseChoice `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` // The Unix timestamp (in seconds) of when the completion was created. 
+ ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` + Usage LLMUsage `json:"usage"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` +} diff --git a/interfaces/plugin.go b/interfaces/plugin.go new file mode 100644 index 000000000..380379f36 --- /dev/null +++ b/interfaces/plugin.go @@ -0,0 +1,8 @@ +package interfaces + +import "context" + +type Plugin interface { + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) + PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) +} diff --git a/interfaces/provider.go b/interfaces/provider.go index b555e1061..cb07a7596 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -1,213 +1,33 @@ package interfaces -// LLMUsage represents token usage information -type LLMUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - Latency float64 `json:"latency,omitempty"` -} - -// LLMInteractionCost represents cost information for LLM interactions -type LLMInteractionCost struct { - Input float64 `json:"input"` - Output float64 `json:"output"` - Total float64 `json:"total"` -} - -// Function represents a function definition for tool calls -type Function struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters interface{} `json:"parameters"` -} - -// Tool represents a tool that can be used with the model -type Tool struct { - Type string `json:"type"` - Function Function `json:"function"` -} - -// ModelParameters represents the parameters for model requests -type ModelParameters struct { - TestRunEntryID *string `json:"testRunEntryId,omitempty"` - PromptTools *[]string `json:"promptTools,omitempty"` - ToolChoice *string `json:"toolChoice,omitempty"` - Tools *[]Tool `json:"tools,omitempty"` - FunctionCall *string `json:"functionCall,omitempty"` - Functions *[]Function `json:"functions,omitempty"` - // Dynamic parameters - ExtraParams map[string]interface{} `json:"-"` -} - -// RequestOptions represents options for model requests -type RequestOptions struct { - UseCache bool `json:"useCache,omitempty"` - WaitForModel bool `json:"waitForModel,omitempty"` - CompletionType string `json:"CompletionType,omitempty"` -} - -// FunctionCall represents a function call in a tool call -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -// ToolCall represents a tool call in a message -type ToolCall struct { - Type string `json:"type"` - ID string `json:"id"` - Function FunctionCall `json:"function"` -} - -// ModelChatMessageRole represents the role of a chat message -type ModelChatMessageRole string - -const ( - RoleAssistant ModelChatMessageRole = "assistant" - RoleUser ModelChatMessageRole = "user" - RoleSystem ModelChatMessageRole = "system" - RoleModel ModelChatMessageRole = "model" - RoleChatbot ModelChatMessageRole = "chatbot" - RoleTool ModelChatMessageRole = "tool" -) - -// CompletionResponseChoice represents a choice in the completion response -type CompletionResponseChoice struct { - Role ModelChatMessageRole `json:"role"` - Content string `json:"content"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` -} - -// CompletionResultChoice represents a choice in the completion result -type CompletionResultChoice struct { - Index int `json:"index"` - Message CompletionResponseChoice 
`json:"message"` - FinishReason string `json:"finish_reason,omitempty"` - LogProbs interface{} `json:"logprobs,omitempty"` -} - -// ToolResult represents the result of a tool call -type ToolResult struct { - Role ModelChatMessageRole `json:"role"` - Content string `json:"content"` - ToolCallID string `json:"tool_call_id"` -} +// TODO third party providers -// ToolCallResult represents a single tool call result -type ToolCallResult struct { - Name string `json:"name"` - Result interface{} `json:"result"` - Type string `json:"type"` - ID string `json:"id"` +type NetworkConfig struct { + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` } -// ToolCallResults represents a collection of tool call results -type ToolCallResults struct { - Version int `json:"version"` - Results []ToolCallResult `json:"results"` +type MetaConfig struct { + SecretAccessKey *string `json:"secret_access_key,omitempty"` + Region *string `json:"region,omitempty"` + SessionToken *string `json:"session_token,omitempty"` + ARN *string `json:"arn,omitempty"` + InferenceProfiles map[string]string `json:"inference_profiles,omitempty"` } -// CompletionResult represents the complete result from a model completion -type CompletionResult struct { - Error *struct { - Code string `json:"code"` - Message string `json:"message"` - Type string `json:"type"` - } `json:"error,omitempty"` - ID string `json:"id"` - Choices []CompletionResultChoice `json:"choices"` - ToolCallResult interface{} `json:"tool_call_result,omitempty"` - ToolCallResults *ToolCallResults `json:"toolCallResults,omitempty"` - Provider SupportedModelProvider `json:"provider,omitempty"` - Usage LLMUsage `json:"usage"` - Cost *LLMInteractionCost `json:"cost,omitempty"` - Model string `json:"model,omitempty"` - Created string `json:"created,omitempty"` - ModelParams interface{} `json:"modelParams,omitempty"` - Trace *struct { - Input interface{} `json:"input"` - Output interface{} `json:"output,omitempty"` - } `json:"trace,omitempty"` - RetrievedContext interface{} `json:"retrievedContext,omitempty"` - VariableBoundRetrievals map[string]interface{} `json:"variableBoundRetrievals,omitempty"` -} - -type SupportedModelProvider string - -const ( - OpenAI SupportedModelProvider = "openai" - Azure SupportedModelProvider = "azure" - HuggingFace SupportedModelProvider = "huggingface" - Anthropic SupportedModelProvider = "anthropic" - Google SupportedModelProvider = "google" - Groq SupportedModelProvider = "groq" - Bedrock SupportedModelProvider = "bedrock" - Maxim SupportedModelProvider = "maxim" - Cohere SupportedModelProvider = "cohere" - Ollama SupportedModelProvider = "ollama" - Lmstudio SupportedModelProvider = "lmstudio" -) - -type Role string - -const ( - UserRole Role = "user" - AssistantRole Role = "assistant" - SystemRole Role = "system" -) - -type Message struct { - //* strict check for roles - Role Role `json:"role"` - //* need to make sure either content or imagecontent is provided - Content *string `json:"content"` - ImageContent *ImageContent `json:"imageContent"` +type ConcurrencyAndBufferSize struct { + Concurrency int `json:"concurrency"` + BufferSize int `json:"buffer_size"` } -type ImageContent struct { - Type string `json:"type"` - ImageURL struct { - URL string `json:"url"` - } `json:"image_url"` +type ProviderConfig struct { + NetworkConfig NetworkConfig `json:"network_config"` + MetaConfig *MetaConfig `json:"meta_config,omitempty"` + ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` } -// 
type Content struct { -// Content *string `json:"content"` -// ImageContent *ImageContent `json:"imageContent"` -// } - -// func (content *Content) MarshalJSON() ([]byte, error) { -// if content.Content != nil { -// return []byte(*content.Content), nil -// } else if content.ImageContent != nil { -// return json.Marshal(content.ImageContent) -// } - -// return nil, fmt.Errorf("invalid content") -// } - -// func (content *Content) UnmarshalJSON(val []byte) error { -// var s any -// json.Unmarshal(val, &s) - -// switch s := s.(type) { -// case string: -// content.Content = &s -// case ImageContent: -// content.ImageContent = &s - -// default: -// return fmt.Errorf("invalid stop") -// } - -// return nil -// } - // Provider defines the interface for AI model providers type Provider interface { GetProviderKey() SupportedModelProvider - TextCompletion(model, key, text string, params *ModelParameters) (*CompletionResult, error) - ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*CompletionResult, error) + TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, error) + ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, error) } diff --git a/providers/anthropic.go b/providers/anthropic.go index c9b874b8c..5920ade65 100644 --- a/providers/anthropic.go +++ b/providers/anthropic.go @@ -1,24 +1,67 @@ package providers import ( - "bifrost/interfaces" - "bytes" "encoding/json" "fmt" - "io" - "net/http" "time" + + "github.com/maximhq/bifrost/interfaces" + "github.com/valyala/fasthttp" + + "github.com/maximhq/maxim-go" ) +type AnthropicToolChoice struct { + Type interfaces.ToolChoiceType `json:"type"` + Name *string `json:"name"` + DisableParallelToolUse *bool `json:"disable_parallel_tool_use"` +} + +type AnthropicTextResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Completion string `json:"completion"` + Model string `json:"model"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` +} + +type AnthropicChatResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Thinking string `json:"thinking,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` + } `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` +} + // AnthropicProvider implements the Provider interface for Anthropic's Claude API type AnthropicProvider struct { - client *http.Client + client *fasthttp.Client } // NewAnthropicProvider creates a new AnthropicProvider instance -func NewAnthropicProvider() *AnthropicProvider { +func NewAnthropicProvider(config *interfaces.ProviderConfig) *AnthropicProvider { return &AnthropicProvider{ - client: &http.Client{Timeout: 30 * time.Second}, + client: &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + }, } } @@ -26,9 +69,54 @@ func (provider *AnthropicProvider) GetProviderKey() 
interfaces.SupportedModelPro return interfaces.Anthropic } +func (provider *AnthropicProvider) PrepareTextCompletionParams(params map[string]interface{}) map[string]interface{} { + // Check if there is a key entry for max_tokens + if maxTokens, exists := params["max_tokens"]; exists { + // Check if max_tokens_to_sample is already present + if _, exists := params["max_tokens_to_sample"]; !exists { + // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + params["max_tokens_to_sample"] = maxTokens + } + delete(params, "max_tokens") + } + return params +} + +func (provider *AnthropicProvider) PrepareToolChoices(params map[string]interface{}) map[string]interface{} { + toolChoice, exists := params["tool_choice"] + if !exists { + return params + } + + switch tc := toolChoice.(type) { + case interfaces.ToolChoice: + anthropicToolChoice := AnthropicToolChoice{ + Type: tc.Type, + Name: &tc.Function.Name, + } + + parallelToolCalls, exists := params["parallel_tool_calls"] + if !exists { + // Still apply the converted tool choice when parallel_tool_calls is absent + params["tool_choice"] = anthropicToolChoice + return params + } + + switch parallelTC := parallelToolCalls.(type) { + case bool: + disableParallel := !parallelTC + anthropicToolChoice.DisableParallelToolUse = &disableParallel + + delete(params, "parallel_tool_calls") + } + + params["tool_choice"] = anthropicToolChoice + } + + return params +} + // TextCompletion implements text completion using Anthropic's API -func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - preparedParams := PrepareParams(params) +func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { + preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params)) // Merge additional parameters requestBody := MergeConfig(map[string]interface{}{ @@ -43,88 +131,72 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param } // Create the request with the JSON body - req, err := http.NewRequest("POST", "https://api.anthropic.com/v1/complete", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI("https://api.anthropic.com/v1/complete") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") req.Header.Set("x-api-key", key) req.Header.Set("anthropic-version", "2023-06-01") + req.SetBody(jsonData) // Send the request - resp, err := provider.client.Do(req) - if err != nil { + if err := provider.client.Do(req, resp); err != nil { return nil, fmt.Errorf("error sending request: %v", err) } - defer resp.Body.Close() - // Read the response body - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response: %v", err) + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + return nil, fmt.Errorf("anthropic error: %s", resp.Body()) } - // Check for error response - if resp.StatusCode != http.StatusOK { - var errorResp struct { - Type string `json:"type"` - Error struct { - Type string `json:"type"` - Message string `json:"message"` - } `json:"error"` - } - if err := json.Unmarshal(body, &errorResp); err != nil { - return nil, fmt.Errorf("error response: %s", string(body)) - } - 
return nil, fmt.Errorf("anthropic error: %s", errorResp.Error.Message) - } + // Read the response body + body := resp.Body() // Parse the response - var result struct { - ID string `json:"id"` - Type string `json:"type"` - Completion string `json:"completion"` - Model string `json:"model"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"usage"` + var response AnthropicTextResponse + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("error parsing response: %v", err) } - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("error parsing response: %v", err) + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("error parsing raw response: %v", err) } // Create the completion result - completionResult := &interfaces.CompletionResult{ - ID: result.ID, - Choices: []interfaces.CompletionResultChoice{ + completionResult := &interfaces.BifrostResponse{ + ID: response.ID, + Choices: []interfaces.BifrostResponseChoice{ { Index: 0, - Message: interfaces.CompletionResponseChoice{ + Message: interfaces.BifrostResponseChoiceMessage{ Role: interfaces.RoleAssistant, - Content: result.Completion, + Content: &response.Completion, }, }, }, Usage: interfaces.LLMUsage{ - PromptTokens: result.Usage.InputTokens, - CompletionTokens: result.Usage.OutputTokens, - TotalTokens: result.Usage.InputTokens + result.Usage.OutputTokens, + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + Model: response.Model, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.Anthropic, + RawResponse: rawResponse, }, - Model: result.Model, - Provider: interfaces.Anthropic, } return completionResult, nil } // ChatCompletion implements chat completion using Anthropic's API -func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - startTime := time.Now() - +func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { // Format messages for Anthropic API var formattedMessages []map[string]interface{} for _, msg := range messages { @@ -136,137 +208,125 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] preparedParams := PrepareParams(params) + // Transform tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + var tools []map[string]interface{} + for _, tool := range *params.Tools { + tools = append(tools, map[string]interface{}{ + "name": tool.Function.Name, + "description": tool.Function.Description, + "input_schema": tool.Function.Parameters, + }) + } + + preparedParams["tools"] = tools + } + // Merge additional parameters requestBody := MergeConfig(map[string]interface{}{ "model": model, "messages": formattedMessages, }, preparedParams) - jsonData, err := json.Marshal(requestBody) + jsonBody, err := json.Marshal(requestBody) if err != nil { return nil, fmt.Errorf("error marshaling request: %v", err) } // Create request - req, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", 
err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI("https://api.anthropic.com/v1/messages") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") req.Header.Set("x-api-key", key) req.Header.Set("anthropic-version", "2023-06-01") + req.SetBody(jsonBody) - // Send request - resp, err := provider.client.Do(req) - if err != nil { - return nil, fmt.Errorf("error sending request: %v", err) - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response: %v", err) + // Make request + if err := provider.client.Do(req, resp); err != nil { + return nil, fmt.Errorf("error making request: %v", err) } - // Check for non-200 status codes - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error: %s", string(body)) + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + return nil, fmt.Errorf("anthropic error: %s", resp.Body()) } - // Calculate latency - latency := time.Since(startTime).Seconds() - - // Decode response - var anthropicResponse struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Thinking string `json:"thinking,omitempty"` - ToolUse *struct { - ID string `json:"id"` - Name string `json:"name"` - Input map[string]interface{} `json:"input"` - } `json:"tool_use,omitempty"` - } `json:"content"` - Model string `json:"model"` - StopReason string `json:"stop_reason,omitempty"` - StopSequence *string `json:"stop_sequence,omitempty"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"usage"` + // Decode structured response + var response AnthropicChatResponse + if err := json.Unmarshal(resp.Body(), &response); err != nil { + return nil, fmt.Errorf("error decoding structured response: %v", err) } - if err := json.Unmarshal(body, &anthropicResponse); err != nil { - return nil, fmt.Errorf("error decoding response: %v", err) + // Decode raw response + var rawResponse interface{} + if err := json.Unmarshal(resp.Body(), &rawResponse); err != nil { + return nil, fmt.Errorf("error decoding raw response: %v", err) } - // Process the response into our CompletionResult format - var content string - var toolCalls []interfaces.ToolCall - var finishReason string + // Process the response into our BifrostResponse format + var choices []interfaces.BifrostResponseChoice // Process content and tool calls - for _, c := range anthropicResponse.Content { + for i, c := range response.Content { + var content string + var toolCalls []interfaces.ToolCall + switch c.Type { case "thinking": - if content == "" { - content = fmt.Sprintf("\n%s\n\n\n", c.Thinking) - } + content = c.Thinking case "text": - content += c.Text + content = c.Text case "tool_use": - if c.ToolUse != nil { - toolCalls = append(toolCalls, interfaces.ToolCall{ - Type: "function", - ID: c.ToolUse.ID, - Function: interfaces.FunctionCall{ - Name: c.ToolUse.Name, - Arguments: string(must(json.Marshal(c.ToolUse.Input))), - }, - }) - finishReason = "tool_calls" + function := interfaces.FunctionCall{ + Name: &c.Name, } + + args, err := json.Marshal(c.Input) + if err != nil { + function.Arguments = fmt.Sprintf("%v", c.Input) + } 
else { + function.Arguments = string(args) + } + + toolCalls = append(toolCalls, interfaces.ToolCall{ + Type: maxim.StrPtr("function"), + ID: &c.ID, + Function: function, + }) } + + choices = append(choices, interfaces.BifrostResponseChoice{ + Index: i, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: interfaces.RoleAssistant, + Content: &content, + ToolCalls: &toolCalls, + }, + FinishReason: &response.StopReason, + StopString: response.StopSequence, + }) } // Create the completion result - result := &interfaces.CompletionResult{ - ID: anthropicResponse.ID, - Choices: []interfaces.CompletionResultChoice{ - { - Index: 0, - Message: interfaces.CompletionResponseChoice{ - Role: interfaces.RoleAssistant, - Content: content, - ToolCalls: toolCalls, - }, - FinishReason: finishReason, - }, - }, + result := &interfaces.BifrostResponse{ + ID: response.ID, + Choices: choices, Usage: interfaces.LLMUsage{ - PromptTokens: anthropicResponse.Usage.InputTokens, - CompletionTokens: anthropicResponse.Usage.OutputTokens, - TotalTokens: anthropicResponse.Usage.InputTokens + anthropicResponse.Usage.OutputTokens, - Latency: latency, + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + Model: response.Model, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.Anthropic, + RawResponse: rawResponse, }, - Model: anthropicResponse.Model, - Provider: interfaces.Anthropic, } return result, nil } - -// Helper function to handle JSON marshaling errors -func must[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} diff --git a/providers/bedrock.go b/providers/bedrock.go new file mode 100644 index 000000000..f49a59f4f --- /dev/null +++ b/providers/bedrock.go @@ -0,0 +1,526 @@ +package providers + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/maximhq/bifrost/interfaces" +) + +type BedrockAnthropicTextResponse struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Stop string `json:"stop"` +} + +type BedrockMistralTextResponse struct { + Outputs []struct { + Text string `json:"text"` + StopReason string `json:"stop_reason"` + } `json:"outputs"` +} + +type BedrockChatResponse struct { + Metrics struct { + Latency int `json:"latencyMs"` + } `json:"metrics"` + Output struct { + Message struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + Role string `json:"role"` + } `json:"message"` + } `json:"output"` + StopReason string `json:"stopReason"` + Usage struct { + InputTokens int `json:"inputTokens"` + OutputTokens int `json:"outputTokens"` + TotalTokens int `json:"totalTokens"` + } `json:"usage"` +} + +type BedrockAnthropicSystemMessage struct { + Text string `json:"text"` +} + +type BedrockAnthropicTextMessage struct { + Type string `json:"type"` + Text string `json:"text"` +} + +type BedrockMistralContent struct { + Text string `json:"text"` +} + +type BedrockMistralChatMessage struct { + Role interfaces.ModelChatMessageRole `json:"role"` + Content []BedrockMistralContent `json:"content"` + ToolCalls *[]BedrockMistralToolCall `json:"tool_calls,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` +} + +type BedrockAnthropicImageMessage struct { + Type string `json:"type"` + Source BedrockAnthropicImageSource `json:"source"` +} + +type BedrockAnthropicImageSource struct { + Type string 
`json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type BedrockMistralToolCall struct { + ID string `json:"id"` + Function interfaces.Function `json:"function"` +} + +type BedrockAnthropicToolCall struct { + ToolSpec BedrockAnthropicToolSpec `json:"toolSpec"` +} + +type BedrockAnthropicToolSpec struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema struct { + Json interface{} `json:"json"` + } `json:"inputSchema"` +} + +type BedrockProvider struct { + client *http.Client + meta *interfaces.MetaConfig +} + +func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider { + return &BedrockProvider{ + client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}, + meta: config.MetaConfig, + } +} + +func (provider *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider { + return interfaces.Bedrock +} + +func (provider *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) { + if provider.meta == nil { + return nil, errors.New("meta config for bedrock is not provided") + } + + region := "us-east-1" + if provider.meta.Region != nil { + region = *provider.meta.Region + } + + // Create the request with the JSON body + req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + if err := SignAWSRequest(req, accessKey, *provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil { + return nil, err + } + + return req, nil +} + +func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.BifrostResponse, error) { + switch model { + case "anthropic.claude-instant-v1:2": + fallthrough + case "anthropic.claude-v2": + fallthrough + case "anthropic.claude-v2:1": + var response BedrockAnthropicTextResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + } + + return &interfaces.BifrostResponse{ + Choices: []interfaces.BifrostResponseChoice{ + { + Index: 0, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: interfaces.RoleAssistant, + Content: &response.Completion, + }, + FinishReason: &response.StopReason, + StopString: &response.Stop, + }, + }, + Model: model, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.Bedrock, + }, + }, nil + + case "mistral.mixtral-8x7b-instruct-v0:1": + fallthrough + case "mistral.mistral-7b-instruct-v0:2": + fallthrough + case "mistral.mistral-large-2402-v1:0": + fallthrough + case "mistral.mistral-large-2407-v1:0": + fallthrough + case "mistral.mistral-small-2402-v1:0": + var response BedrockMistralTextResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + } + + var choices []interfaces.BifrostResponseChoice + for i, output := range response.Outputs { + choices = append(choices, interfaces.BifrostResponseChoice{ + Index: i, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: interfaces.RoleAssistant, + Content: &output.Text, + }, + FinishReason: &output.StopReason, + }) + } + + return &interfaces.BifrostResponse{ + Choices: choices, + Model: model, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.Bedrock, + }, + 
}, nil + } + + return nil, fmt.Errorf("invalid model choice: %s", model) +} + +func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) { + switch model { + case "anthropic.claude-instant-v1:2": + fallthrough + case "anthropic.claude-v2": + fallthrough + case "anthropic.claude-v2:1": + fallthrough + case "anthropic.claude-3-sonnet-20240229-v1:0": + fallthrough + case "anthropic.claude-3-5-sonnet-20240620-v1:0": + fallthrough + case "anthropic.claude-3-5-sonnet-20241022-v2:0": + fallthrough + case "anthropic.claude-3-5-haiku-20241022-v1:0": + fallthrough + case "anthropic.claude-3-opus-20240229-v1:0": + fallthrough + case "anthropic.claude-3-7-sonnet-20250219-v1:0": + // Add system messages if present + var systemMessages []BedrockAnthropicSystemMessage + for _, msg := range messages { + if msg.Role == interfaces.RoleSystem { + //TODO handling image inputs here + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *msg.Content, + }) + } + } + + // Format messages for Bedrock API + var bedrockMessages []map[string]interface{} + for _, msg := range messages { + if msg.Role != interfaces.RoleSystem { + var content any + if msg.Content != nil { + content = BedrockAnthropicTextMessage{ + Type: "text", + Text: *msg.Content, + } + } else if msg.ImageContent != nil { + content = BedrockAnthropicImageMessage{ + Type: "image", + Source: BedrockAnthropicImageSource{ + Type: msg.ImageContent.Type, + MediaType: msg.ImageContent.MediaType, + Data: msg.ImageContent.URL, + }, + } + } + + bedrockMessages = append(bedrockMessages, map[string]interface{}{ + "role": msg.Role, + "content": []interface{}{content}, + }) + } + } + + body := map[string]interface{}{ + "messages": bedrockMessages, + } + + if len(systemMessages) > 0 { + var messages []string + for _, message := range systemMessages { + messages = append(messages, message.Text) + } + + body["system"] = strings.Join(messages, " ") + } + + return body, nil + + case "mistral.mistral-large-2402-v1:0": + fallthrough + case "mistral.mistral-large-2407-v1:0": + var bedrockMessages []BedrockMistralChatMessage + for _, msg := range messages { + var filteredToolCalls []BedrockMistralToolCall + if msg.ToolCalls != nil { + for _, toolCall := range *msg.ToolCalls { + filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{ + ID: *toolCall.ID, + Function: toolCall.Function, + }) + } + } + + message := BedrockMistralChatMessage{ + Role: msg.Role, + Content: []BedrockMistralContent{ + {Text: *msg.Content}, + }, + } + + if len(filteredToolCalls) > 0 { + message.ToolCalls = &filteredToolCalls + } + + bedrockMessages = append(bedrockMessages, message) + } + + body := map[string]interface{}{ + "messages": bedrockMessages, + } + + return body, nil + } + + return nil, fmt.Errorf("invalid model choice: %s", model) +} + +func (provider *BedrockProvider) GetChatCompletionTools(params *interfaces.ModelParameters, model string) []BedrockAnthropicToolCall { + var tools []BedrockAnthropicToolCall + + switch model { + case "anthropic.claude-instant-v1:2": + fallthrough + case "anthropic.claude-v2": + fallthrough + case "anthropic.claude-v2:1": + fallthrough + case "anthropic.claude-3-sonnet-20240229-v1:0": + fallthrough + case "anthropic.claude-3-5-sonnet-20240620-v1:0": + fallthrough + case "anthropic.claude-3-5-sonnet-20241022-v2:0": + fallthrough + case "anthropic.claude-3-5-haiku-20241022-v1:0": + fallthrough + case "anthropic.claude-3-opus-20240229-v1:0": 
+ fallthrough + case "anthropic.claude-3-7-sonnet-20250219-v1:0": + for _, tool := range *params.Tools { + tools = append(tools, BedrockAnthropicToolCall{ + ToolSpec: BedrockAnthropicToolSpec{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: struct { + Json interface{} `json:"json"` + }{ + Json: tool.Function.Parameters, + }, + }, + }) + } + } + + return tools +} + +func (provider *BedrockProvider) PrepareTextCompletionParams(params map[string]interface{}, model string) map[string]interface{} { + switch model { + case "anthropic.claude-instant-v1:2": + fallthrough + case "anthropic.claude-v2": + fallthrough + case "anthropic.claude-v2:1": + // Check if there is a key entry for max_tokens + if maxTokens, exists := params["max_tokens"]; exists { + // Check if max_tokens_to_sample is already present + if _, exists := params["max_tokens_to_sample"]; !exists { + // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + params["max_tokens_to_sample"] = maxTokens + } + delete(params, "max_tokens") + } + } + return params +} + +func (provider *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { + preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params), model) + + requestBody := MergeConfig(map[string]interface{}{ + "prompt": text, + }, preparedParams) + + // Marshal the request body + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + // Create the signed request with correct operation name + req, err := provider.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + // Execute the request + resp, err := provider.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %v", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bedrock API error: %s", string(body)) + } + + result, err := provider.GetTextCompletionResult(body, model) + if err != nil { + return nil, fmt.Errorf("failed to parse response body: %v", err) + } + + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("failed to parse raw response: %v", err) + } + + result.ExtraFields.RawResponse = rawResponse + + return result, nil +} + +func (provider *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { + messageBody, err := provider.PrepareChatCompletionMessages(messages, model) + if err != nil { + return nil, fmt.Errorf("error preparing messages: %v", err) + } + + preparedParams := PrepareParams(params) + + // Transform tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + preparedParams["tools"] = provider.GetChatCompletionTools(params, model) + } + + requestBody := MergeConfig(messageBody, preparedParams) + + // Marshal the request body + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + // Format the path with proper model identifier + path := fmt.Sprintf("%s/converse", model) + 
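+	// If an inference profile is mapped to this model in the provider's meta
+	// config, route the call through the profile instead of the bare model ID.
+	// The profile identifier is joined to the ARN and URL-escaped, since the ARN
+	// contains '/' and ':' characters that would otherwise break the request path.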
+ if provider.meta != nil && provider.meta.InferenceProfiles != nil { + if inferenceProfileId, ok := provider.meta.InferenceProfiles[model]; ok { + if provider.meta.ARN != nil { + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.ARN, inferenceProfileId)) + path = fmt.Sprintf("%s/converse", encodedModelIdentifier) + } + } + } + + // Create the signed request + req, err := provider.PrepareReq(path, jsonData, key) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + // Execute the request + resp, err := provider.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %v", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bedrock API error: %s", string(body)) + } + + var response BedrockChatResponse + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + } + + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("failed to parse raw response: %v", err) + } + + var choices []interfaces.BifrostResponseChoice + for i, choice := range response.Output.Message.Content { + choices = append(choices, interfaces.BifrostResponseChoice{ + Index: i, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: interfaces.RoleAssistant, + Content: &choice.Text, + }, + FinishReason: &response.StopReason, + }) + } + + latency := float64(response.Metrics.Latency) + + result := &interfaces.BifrostResponse{ + Choices: choices, + Usage: interfaces.LLMUsage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + }, + Model: model, + + ExtraFields: interfaces.BifrostResponseExtraFields{ + Latency: &latency, + Provider: interfaces.Bedrock, + RawResponse: rawResponse, + }, + } + + return result, nil +} diff --git a/providers/cohere.go b/providers/cohere.go new file mode 100644 index 000000000..4f14e3d67 --- /dev/null +++ b/providers/cohere.go @@ -0,0 +1,290 @@ +package providers + +import ( + "encoding/json" + "fmt" + "slices" + "time" + + "github.com/maximhq/bifrost/interfaces" + "github.com/valyala/fasthttp" +) + +// CohereParameterDefinition represents a parameter definition for a Cohere tool +type CohereParameterDefinition struct { + Type string `json:"type"` + Description *string `json:"description,omitempty"` + Required bool `json:"required"` +} + +// CohereTool represents a tool definition for Cohere API +type CohereTool struct { + Name string `json:"name"` + Description string `json:"description"` + ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` +} + +type CohereToolCall struct { + Name string `json:"name"` + Parameters interface{} `json:"parameters"` +} + +// CohereChatResponse represents the response from Cohere's chat API +type CohereChatResponse struct { + ResponseID string `json:"response_id"` + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ChatHistory []struct { + Role interfaces.ModelChatMessageRole `json:"role"` + Message string `json:"message"` + ToolCalls []CohereToolCall `json:"tool_calls"` + } `json:"chat_history"` + FinishReason string `json:"finish_reason"` + Meta struct { + APIVersion struct { + 
Version string `json:"version"`
+		} `json:"api_version"`
+		BilledUnits struct {
+			InputTokens  float64 `json:"input_tokens"`
+			OutputTokens float64 `json:"output_tokens"`
+		} `json:"billed_units"`
+		Tokens struct {
+			InputTokens  float64 `json:"input_tokens"`
+			OutputTokens float64 `json:"output_tokens"`
+		} `json:"tokens"`
+	} `json:"meta"`
+	ToolCalls []CohereToolCall `json:"tool_calls"`
+}
+
+// CohereProvider implements the Provider interface for Cohere
+type CohereProvider struct {
+	client *fasthttp.Client
+}
+
+// NewCohereProvider creates a new Cohere provider instance
+func NewCohereProvider(config *interfaces.ProviderConfig) *CohereProvider {
+	return &CohereProvider{
+		client: &fasthttp.Client{
+			ReadTimeout:     time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+			WriteTimeout:    time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+			MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize,
+		},
+	}
+}
+
+func (provider *CohereProvider) GetProviderKey() interfaces.SupportedModelProvider {
+	return interfaces.Cohere
+}
+
+func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+	return nil, fmt.Errorf("text completion is not supported by Cohere")
+}
+
+func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+	// Get the last message and chat history
+	lastMessage := messages[len(messages)-1]
+	chatHistory := messages[:len(messages)-1]
+
+	// Transform chat history
+	var cohereHistory []map[string]interface{}
+	for _, msg := range chatHistory {
+		cohereHistory = append(cohereHistory, map[string]interface{}{
+			"role":    msg.Role,
+			"message": msg.Content,
+		})
+	}
+
+	preparedParams := PrepareParams(params)
+
+	// Prepare request body
+	requestBody := MergeConfig(map[string]interface{}{
+		"message":      lastMessage.Content,
+		"chat_history": cohereHistory,
+		"model":        model,
+	}, preparedParams)
+
+	// Add tools if present
+	if params != nil && params.Tools != nil && len(*params.Tools) > 0 {
+		var tools []CohereTool
+		for _, tool := range *params.Tools {
+			parameterDefinitions := make(map[string]CohereParameterDefinition)
+			functionParams := tool.Function.Parameters
+			for name, prop := range functionParams.Properties {
+				propMap, ok := prop.(map[string]interface{})
+				if ok {
+					paramDef := CohereParameterDefinition{
+						Required: slices.Contains(functionParams.Required, name),
+					}
+
+					if typeStr, ok := propMap["type"].(string); ok {
+						paramDef.Type = typeStr
+					}
+
+					if desc, ok := propMap["description"].(string); ok {
+						paramDef.Description = &desc
+					}
+
+					parameterDefinitions[name] = paramDef
+				}
+			}
+
+			tools = append(tools, CohereTool{
+				Name:                 tool.Function.Name,
+				Description:          tool.Function.Description,
+				ParameterDefinitions: parameterDefinitions,
+			})
+		}
+		requestBody["tools"] = tools
+	}
+
+	// Marshal request body
+	jsonBody, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("error marshaling request: %v", err)
+	}
+
+	// Create request
+	req := fasthttp.AcquireRequest()
+	resp := fasthttp.AcquireResponse()
+	defer fasthttp.ReleaseRequest(req)
+	defer fasthttp.ReleaseResponse(resp)
+
+	req.SetRequestURI("https://api.cohere.ai/v1/chat")
+	req.Header.SetMethod("POST")
+	req.Header.SetContentType("application/json")
+	req.Header.Set("Authorization", "Bearer "+key)
+	req.SetBody(jsonBody)
+
+	// Make request
+	if err := provider.client.Do(req, resp); err != nil {
+		return nil, fmt.Errorf("error making request: %v", err)
+	}
+
+	// Handle error response
+	if resp.StatusCode() != fasthttp.StatusOK {
+		return nil, fmt.Errorf("cohere error: %s", resp.Body())
+	}
+
+	// Read response body
+	body := resp.Body()
+
+	// Decode response
+	var response CohereChatResponse
+	if err := json.Unmarshal(body, &response); err != nil {
+		return nil, fmt.Errorf("failed to parse Cohere response: %v", err)
+	}
+
+	// Parse raw response
+	var rawResponse interface{}
+	if err := json.Unmarshal(body, &rawResponse); err != nil {
+		return nil, fmt.Errorf("failed to parse raw response: %v", err)
+	}
+
+	// Transform tool calls if present
+	var toolCalls []interfaces.ToolCall
+	if response.ToolCalls != nil {
+		for _, tool := range response.ToolCalls {
+			function := interfaces.FunctionCall{
+				Name: &tool.Name,
+			}
+
+			args, err := json.Marshal(tool.Parameters)
+			if err != nil {
+				function.Arguments = fmt.Sprintf("%v", tool.Parameters)
+			} else {
+				function.Arguments = string(args)
+			}
+
+			toolCalls = append(toolCalls, interfaces.ToolCall{
+				Function: function,
+			})
+		}
+	}
+
+	// Get role and content from the last message in chat history
+	var role interfaces.ModelChatMessageRole
+	var content string
+	if len(response.ChatHistory) > 0 {
+		lastMsg := response.ChatHistory[len(response.ChatHistory)-1]
+		role = lastMsg.Role
+		content = lastMsg.Message
+	} else {
+		role = interfaces.RoleChatbot
+		content = response.Text
+	}
+
+	// Create completion result
+	result := &interfaces.BifrostResponse{
+		ID: response.ResponseID,
+		Choices: []interfaces.BifrostResponseChoice{
+			{
+				Index: 0,
+				Message: interfaces.BifrostResponseChoiceMessage{
+					Role:      role,
+					Content:   &content,
+					ToolCalls: &toolCalls,
+				},
+				FinishReason: &response.FinishReason,
+			},
+		},
+		Usage: interfaces.LLMUsage{
+			PromptTokens:     int(response.Meta.Tokens.InputTokens),
+			CompletionTokens: int(response.Meta.Tokens.OutputTokens),
+			TotalTokens:      int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens),
+		},
+
+		ExtraFields: interfaces.BifrostResponseExtraFields{
+			Provider: interfaces.Cohere,
+			BilledUsage: &interfaces.BilledLLMUsage{
+				PromptTokens:     float64Ptr(response.Meta.BilledUnits.InputTokens),
+				CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens),
+			},
+			ChatHistory: convertChatHistory(response.ChatHistory),
+			RawResponse: rawResponse,
+		},
+		Model: model,
+	}
+
+	return result, nil
+}
+
+// Helper function to convert chat history to the correct type
+func convertChatHistory(history []struct {
+	Role      interfaces.ModelChatMessageRole `json:"role"`
+	Message   string                          `json:"message"`
+	ToolCalls []CohereToolCall                `json:"tool_calls"`
+}) *[]interfaces.BifrostResponseChoiceMessage {
+	converted := make([]interfaces.BifrostResponseChoiceMessage, len(history))
+	for i, msg := range history {
+		var toolCalls []interfaces.ToolCall
+		if msg.ToolCalls != nil {
+			for _, tool := range msg.ToolCalls {
+				function := interfaces.FunctionCall{
+					Name: &tool.Name,
+				}
+
+				args, err := json.Marshal(tool.Parameters)
+				if err != nil {
+					function.Arguments = fmt.Sprintf("%v", tool.Parameters)
+				} else {
+					function.Arguments = string(args)
+				}
+
+				toolCalls = append(toolCalls, interfaces.ToolCall{
+					Function: function,
+				})
+			}
+		}
+		converted[i] = interfaces.BifrostResponseChoiceMessage{
+			Role:      msg.Role,
+			Content:   &msg.Message,
+			ToolCalls: &toolCalls,
+		}
+	}
+	return &converted
+}
+
+// Helper function to create a pointer to a float64
+func
float64Ptr(f float64) *float64 { + return &f +} diff --git a/providers/openai.go b/providers/openai.go index bdc1385d4..3c12ec768 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -1,24 +1,38 @@ package providers import ( - "bifrost/interfaces" - "bytes" "encoding/json" "fmt" - "net/http" "time" + + "github.com/maximhq/bifrost/interfaces" + "github.com/valyala/fasthttp" ) +type OpenAIResponse struct { + ID string `json:"id"` + Object string `json:"object"` // text.completion or chat.completion + Choices []interfaces.BifrostResponseChoice `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` // The Unix timestamp (in seconds). + ServiceTier *string `json:"service_tier"` + SystemFingerprint *string `json:"system_fingerprint"` + Usage interfaces.LLMUsage `json:"usage"` +} + // OpenAIProvider implements the Provider interface for OpenAI type OpenAIProvider struct { - //* Do we even need it? - client *http.Client + client *fasthttp.Client } // NewOpenAIProvider creates a new OpenAI provider instance -func NewOpenAIProvider() *OpenAIProvider { +func NewOpenAIProvider(config *interfaces.ProviderConfig) *OpenAIProvider { return &OpenAIProvider{ - client: &http.Client{Timeout: time.Second * 30}, + client: &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + }, } } @@ -27,35 +41,14 @@ func (provider *OpenAIProvider) GetProviderKey() interfaces.SupportedModelProvid } // TextCompletion performs text completion -func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { return nil, fmt.Errorf("text completion is not supported by OpenAI") } -// sanitizeParameters cleans up the parameters for OpenAI -func (provider *OpenAIProvider) sanitizeParameters(params *interfaces.ModelParameters) *interfaces.ModelParameters { - sanitized := params - if sanitized == nil { - return nil - } - - if params.ExtraParams != nil { - // For logprobs, if it's disabled, we remove top_logprobs - if _, exists := params.ExtraParams["logprobs"]; !exists { - delete(sanitized.ExtraParams, "top_logprobs") - } - } - - return sanitized -} - -// ChatCompletion implements chat completion using OpenAI's API -func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - startTime := time.Now() - +func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { // Format messages for OpenAI API var openAIMessages []map[string]interface{} for _, msg := range messages { - var content any if msg.Content != nil { content = msg.Content @@ -69,8 +62,6 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int }) } - // Sanitize parameters - params = provider.sanitizeParameters(params) preparedParams := PrepareParams(params) requestBody := MergeConfig(map[string]interface{}{ @@ -84,75 +75,55 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int } // Create request - req, err := 
http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(jsonBody)) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - // Add headers - req.Header.Set("Content-Type", "application/json") + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI("https://api.openai.com/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key) + req.SetBody(jsonBody) // Make request - resp, err := provider.client.Do(req) - if err != nil { + if err := provider.client.Do(req, resp); err != nil { return nil, fmt.Errorf("error making request: %v", err) } - defer resp.Body.Close() - - latency := time.Since(startTime).Seconds() // Handle error response - if resp.StatusCode != http.StatusOK { - var errorResp struct { - Error struct { - Message string `json:"message"` - Type string `json:"type"` - Param any `json:"param"` - Code string `json:"code"` - } `json:"error"` - } - if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil { - return nil, fmt.Errorf("error decoding error response: %v", err) - } - return nil, fmt.Errorf("OpenAI error: %s", errorResp.Error.Message) + if resp.StatusCode() != fasthttp.StatusOK { + return nil, fmt.Errorf("OpenAI error: %s", resp.Body()) } - // Decode response - var rawResult struct { - ID string `json:"id"` - Choices []interfaces.CompletionResultChoice `json:"choices"` - Usage interfaces.LLMUsage `json:"usage"` - Model string `json:"model"` - Created interface{} `json:"created"` - } + body := resp.Body() - if err := json.NewDecoder(resp.Body).Decode(&rawResult); err != nil { - return nil, fmt.Errorf("error decoding response: %v", err) + // Decode structured response + var response OpenAIResponse + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("error decoding structured response: %v", err) } - // Convert the raw result to CompletionResult - result := &interfaces.CompletionResult{ - ID: rawResult.ID, - Choices: rawResult.Choices, - Usage: rawResult.Usage, - Model: rawResult.Model, + // Decode raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("error decoding raw response: %v", err) } - // Handle the created field conversion - if rawResult.Created != nil { - switch v := rawResult.Created.(type) { - case float64: - // Convert Unix timestamp to string - result.Created = fmt.Sprintf("%d", int64(v)) - case string: - result.Created = v - } + result := &interfaces.BifrostResponse{ + ID: response.ID, + Choices: response.Choices, + Object: response.Object, + Usage: response.Usage, + ServiceTier: response.ServiceTier, + SystemFingerprint: response.SystemFingerprint, + Created: response.Created, + Model: response.Model, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.OpenAI, + RawResponse: rawResponse, + }, } - // Add provider-specific information - result.Provider = interfaces.OpenAI - result.Usage.Latency = latency - return result, nil } diff --git a/providers/utils.go b/providers/utils.go index bd925de75..9acbef7ab 100644 --- a/providers/utils.go +++ b/providers/utils.go @@ -1,8 +1,24 @@ package providers import ( - "bifrost/interfaces" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" "reflect" + "strings" + "time" + + 
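+	// The AWS SDK v2 packages below are imported solely for SigV4 request signing in SignAWSRequest.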
"github.com/maximhq/bifrost/interfaces" + + "maps" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" ) // MergeConfig merges default config with custom parameters @@ -35,7 +51,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} { typ := val.Type() // Iterate through all fields - for i := 0; i < val.NumField(); i++ { + for i := range val.NumField() { field := val.Field(i) fieldType := typ.Field(i) @@ -50,6 +66,9 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} { continue } + // Strip out ,omitempty and others from the tag + jsonTag = strings.Split(jsonTag, ",")[0] + // Handle pointer fields if field.Kind() == reflect.Ptr && !field.IsNil() { flatParams[jsonTag] = field.Elem().Interface() @@ -57,9 +76,64 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} { } // Handle ExtraParams - for k, v := range params.ExtraParams { - flatParams[k] = v - } + maps.Copy(flatParams, params.ExtraParams) return flatParams } + +func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string) error { + // Set required headers before signing + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // Calculate SHA256 hash of the request body + var bodyHash string + if req.Body != nil { + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("failed to read request body: %v", err) + } + // Restore the body for subsequent reads + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + hash := sha256.Sum256(bodyBytes) + bodyHash = hex.EncodeToString(hash[:]) + } else { + // For empty body, use the hash of an empty string + hash := sha256.Sum256([]byte{}) + bodyHash = hex.EncodeToString(hash[:]) + } + + cfg, err := config.LoadDefaultConfig(context.TODO(), + config.WithRegion(region), + config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) { + creds := aws.Credentials{ + AccessKeyID: accessKey, + SecretAccessKey: secretKey, + } + if sessionToken != nil { + creds.SessionToken = *sessionToken + } + return creds, nil + })), + ) + if err != nil { + return fmt.Errorf("failed to load AWS config: %v", err) + } + + // Create the AWS signer + signer := v4.NewSigner() + + // Get credentials + creds, err := cfg.Credentials.Retrieve(context.TODO()) + if err != nil { + return fmt.Errorf("failed to retrieve credentials: %v", err) + } + + // Sign the request with AWS Signature V4 + if err := signer.SignHTTP(context.TODO(), creds, req, bodyHash, service, region, time.Now()); err != nil { + return fmt.Errorf("failed to sign request: %v", err) + } + + return nil +} diff --git a/tests/account.go b/tests/account.go index 81aa72e34..174fee9b7 100644 --- a/tests/account.go +++ b/tests/account.go @@ -1,119 +1,110 @@ package tests import ( - "bifrost/interfaces" - "bifrost/providers" "fmt" -) - -// BaseAccount provides a basic implementation of the Account interface -type BaseAccount struct { - providers map[string]interfaces.Provider - keys map[string][]interfaces.Key - config map[string]interfaces.ConcurrencyAndBufferSize -} - -type ProviderConfig map[string]struct { - Keys []interfaces.Key `json:"keys"` - ConcurrencyConfig interfaces.ConcurrencyAndBufferSize `json:"concurrency_config"` -} - -func (baseAccount *BaseAccount) Init(config ProviderConfig) error { - baseAccount.providers = 
make(map[string]interfaces.Provider) - baseAccount.keys = make(map[string][]interfaces.Key) - baseAccount.config = make(map[string]interfaces.ConcurrencyAndBufferSize) - - for providerKey, providerData := range config { - // Create provider instance based on the key - provider, err := baseAccount.createProvider(providerKey, providerData.Keys) - if err != nil { - return fmt.Errorf("failed to create provider %s: %v", providerKey, err) - } - - fmt.Println("✅ provider created") - - // Add provider to the account - baseAccount.AddProvider(provider) + "os" - // Add keys for the provider - for _, keyData := range providerData.Keys { - key := interfaces.Key{ - Value: keyData.Value, - Models: keyData.Models, - Weight: keyData.Weight, - } + "github.com/maximhq/bifrost/interfaces" - baseAccount.AddKey(providerKey, key) - } + "github.com/maximhq/maxim-go" +) - // Set provider configuration - baseAccount.SetProviderConcurrencyConfig(providerKey, providerData.ConcurrencyConfig) - } +// BaseAccount provides a basic implementation of the Account interface for Anthropic and OpenAI providers +type BaseAccount struct{} - return nil +// GetInitiallyConfiguredProviderKeys returns all provider keys +func (baseAccount *BaseAccount) GetInitiallyConfiguredProviders() ([]interfaces.SupportedModelProvider, error) { + return []interfaces.SupportedModelProvider{interfaces.OpenAI, interfaces.Anthropic, interfaces.Bedrock}, nil } -// createProvider creates a new provider instance based on the provider key -func (ba *BaseAccount) createProvider(providerKey string, keys []interfaces.Key) (interfaces.Provider, error) { - if len(keys) == 0 { - return nil, fmt.Errorf("no keys found for provider: %s", providerKey) - } - - switch interfaces.SupportedModelProvider(providerKey) { +// GetKeysForProvider returns all keys associated with a provider +func (baseAccount *BaseAccount) GetKeysForProvider(providerKey interfaces.SupportedModelProvider) ([]interfaces.Key, error) { + switch providerKey { case interfaces.OpenAI: - return providers.NewOpenAIProvider(), nil + return []interfaces.Key{ + { + Value: os.Getenv("OPEN_AI_API_KEY"), + Models: []string{"gpt-4o-mini"}, + Weight: 1.0, + }, + }, nil case interfaces.Anthropic: - return providers.NewAnthropicProvider(), nil + return []interfaces.Key{ + { + Value: os.Getenv("ANTHROPIC_API_KEY"), + Models: []string{"claude-3-7-sonnet-20250219", "claude-2.1"}, + Weight: 1.0, + }, + }, nil + case interfaces.Bedrock: + return []interfaces.Key{ + { + Value: os.Getenv("BEDROCK_API_KEY"), + Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, + Weight: 1.0, + }, + }, nil + case interfaces.Cohere: + return []interfaces.Key{ + { + Value: os.Getenv("COHERE_API_KEY"), + Models: []string{"command-a-03-2025"}, + Weight: 1.0, + }, + }, nil default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } } -// GetInitiallyConfiguredProviders returns all configured providers -func (ba *BaseAccount) GetInitiallyConfiguredProviders() ([]interfaces.Provider, error) { - providers := make([]interfaces.Provider, 0, len(ba.providers)) - for _, provider := range ba.providers { - providers = append(providers, provider) - } - - return providers, nil -} - -// GetKeysForProvider returns all keys associated with a provider -func (ba *BaseAccount) GetKeysForProvider(provider interfaces.Provider) ([]interfaces.Key, error) { - providerKey := string(provider.GetProviderKey()) - if keys, exists := 
ba.keys[providerKey]; exists { - return keys, nil - } - - return nil, fmt.Errorf("no keys found for provider: %s", providerKey) -} - // GetConcurrencyAndBufferSizeForProvider returns the concurrency and buffer size settings for a provider -func (ba *BaseAccount) GetConcurrencyAndBufferSizeForProvider(provider interfaces.Provider) (interfaces.ConcurrencyAndBufferSize, error) { - providerKey := string(provider.GetProviderKey()) - if config, exists := ba.config[providerKey]; exists { - return config, nil +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.SupportedModelProvider) (*interfaces.ProviderConfig, error) { + switch providerKey { + case interfaces.OpenAI: + return &interfaces.ProviderConfig{ + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case interfaces.Anthropic: + return &interfaces.ProviderConfig{ + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case interfaces.Bedrock: + return &interfaces.ProviderConfig{ + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + MetaConfig: &interfaces.MetaConfig{ + SecretAccessKey: maxim.StrPtr(os.Getenv("BEDROCK_ACCESS_KEY")), + Region: maxim.StrPtr("us-east-1"), + }, + ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case interfaces.Cohere: + return &interfaces.ProviderConfig{ + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) } - - // Default values if not configured - return interfaces.ConcurrencyAndBufferSize{ - Concurrency: 5, - BufferSize: 100, - }, nil -} - -// AddProvider adds a new provider to the account -func (ba *BaseAccount) AddProvider(provider interfaces.Provider) { - ba.providers[string(provider.GetProviderKey())] = provider -} - -// AddKey adds a new key for a provider -func (ba *BaseAccount) AddKey(providerKey string, key interfaces.Key) { - ba.keys[providerKey] = append(ba.keys[providerKey], key) -} - -// SetProviderConfig sets the concurrency and buffer size for a provider -func (ba *BaseAccount) SetProviderConcurrencyConfig(providerKey string, config interfaces.ConcurrencyAndBufferSize) { - ba.config[providerKey] = config } diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go index 952a0c16e..0fa2eb793 100644 --- a/tests/anthropic_test.go +++ b/tests/anthropic_test.go @@ -1,41 +1,48 @@ package tests import ( - "bifrost" - "bifrost/interfaces" + "context" "fmt" "testing" "time" + + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/interfaces" ) // setupAnthropicRequests sends multiple test requests to Anthropic func setupAnthropicRequests(bifrost *bifrost.Bifrost) { - anthropicMessages := []string{ - "What's your favorite programming language?", - "Can you help me write a Go function?", - "What's the best way to learn programming?", - "Tell me about artificial intelligence.", + ctx := context.Background() + + maxTokens := 4096 + + params := interfaces.ModelParameters{ + MaxTokens: &maxTokens, } + // Text completion request go func() { - config := 
interfaces.ModelParameters{ - ExtraParams: map[string]interface{}{ - "max_tokens_to_sample": 4096, - }, - } + text := "Hello world!" - result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, "claude-2.1", "Hello world!", &config) + result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ + Model: "claude-2.1", + Input: interfaces.RequestInput{ + TextInput: &text, + }, + Params: ¶ms, + }, ctx) if err != nil { fmt.Println("Error:", err) } else { - fmt.Println("🤖 Text Completion Result:", result.Choices[0].Message.Content) + fmt.Println("🤖 Text Completion Result:", *result.Choices[0].Message.Content) } }() - config := interfaces.ModelParameters{ - ExtraParams: map[string]interface{}{ - "max_tokens": 4096, - }, + // Regular chat completion requests + anthropicMessages := []string{ + "Hello! How are you today?", + "Tell me a joke!", + "What's your favorite programming language?", } for i, message := range anthropicMessages { @@ -44,15 +51,86 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { time.Sleep(delay) messages := []interfaces.Message{ { - Role: interfaces.UserRole, + Role: interfaces.RoleUser, Content: &msg, }, } - result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, "claude-3-7-sonnet-20250219", messages, &config) + result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ + Model: "claude-3-7-sonnet-20250219", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + if err != nil { fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err) } else { - fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, result.Choices[0].Message.Content) + fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) + } + }(message, delay, i) + } + + // Tool calls test + setupAnthropicToolCalls(bifrost, ctx) +} + +// setupAnthropicToolCalls tests Anthropic's function calling capability +func setupAnthropicToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + anthropicMessages := []string{ + "What's the weather like in Mumbai?", + } + + maxTokens := 4096 + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + MaxTokens: &maxTokens, + } + + for i, message := range anthropicMessages { + delay := time.Duration(500+100*i) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ + Model: "claude-3-7-sonnet-20250219", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + + if err != nil { + fmt.Printf("Error in Anthropic tool call request %d: %v\n", index+1, err) + } else { + toolCall := *result.Choices[1].Message.ToolCalls + fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) } }(message, delay, i) } diff --git a/tests/bedrock_test.go b/tests/bedrock_test.go new file mode 100644 index 000000000..9c1791931 --- /dev/null +++ b/tests/bedrock_test.go @@ -0,0 +1,154 @@ +package tests + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/interfaces" +) + +// setupBedrockRequests sends multiple test requests to Bedrock +func setupBedrockRequests(bifrost *bifrost.Bifrost) { + ctx := context.Background() + + maxTokens := 4096 + + params := interfaces.ModelParameters{ + MaxTokens: &maxTokens, + } + + // Text completion request + go func() { + text := "\n\nHuman:\n\nAssistant:" + + result, err := bifrost.TextCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{ + Model: "anthropic.claude-v2:1", + Input: interfaces.RequestInput{ + TextInput: &text, + }, + Params: ¶ms, + }, ctx) + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("🤖 Text Completion Result:", *result.Choices[0].Message.Content) + } + }() + + // Regular chat completion requests + bedrockMessages := []string{ + "Hello! 
How are you today?", + "Tell me a joke!", + "What's your favorite programming language?", + } + + for i, message := range bedrockMessages { + delay := time.Duration(500+100*i) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{ + Model: "anthropic.claude-3-sonnet-20240229-v1:0", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + + if err != nil { + fmt.Printf("Error in Bedrock request %d: %v\n", index+1, err) + } else { + fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) + } + }(message, delay, i) + } + + // Tool calls test + setupBedrockToolCalls(bifrost, ctx) +} + +// setupBedrockToolCalls tests Bedrock's function calling capability +func setupBedrockToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + bedrockMessages := []string{ + "What's the weather like in Mumbai?", + } + + maxTokens := 4096 + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + MaxTokens: &maxTokens, + } + + for i, message := range bedrockMessages { + delay := time.Duration(500+100*i) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{ + Model: "anthropic.claude-3-sonnet-20240229-v1:0", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + + if err != nil { + fmt.Printf("Error in Bedrock tool call request %d: %v\n", index+1, err) + } else { + if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) + } else { + fmt.Printf("🤖 No tool calls in response %d\n", index+1) + fmt.Println("Raw JSON Response", result.ExtraFields.RawResponse) + } + } + }(message, delay, i) + } +} + +func TestBedrock(t *testing.T) { + bifrost, err := getBifrost() + if err != nil { + t.Fatalf("Error initializing bifrost: %v", err) + return + } + + setupBedrockRequests(bifrost) + + bifrost.Cleanup() +} diff --git a/tests/cohere_test.go b/tests/cohere_test.go new file mode 100644 index 000000000..750d59399 --- /dev/null +++ b/tests/cohere_test.go @@ -0,0 +1,137 @@ +package tests + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/interfaces" +) + +// setupCohereRequests sends multiple test requests to Cohere +func setupCohereRequests(bifrost *bifrost.Bifrost) { + text := "Hello world!" 
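+	// Note: Cohere's provider implements chat completion only, so the text
+	// completion request below is expected to return an "is not supported"
+	// error (see providers/cohere.go).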
+	ctx := context.Background()
+
+	// Text completion request
+	go func() {
+		result, err := bifrost.TextCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{
+			Model: "command-a-03-2025",
+			Input: interfaces.RequestInput{
+				TextInput: &text,
+			},
+			Params: nil,
+		}, ctx)
+		if err != nil {
+			fmt.Println("Error:", err)
+		} else {
+			fmt.Println("🐒 Text Completion Result:", *result.Choices[0].Message.Content)
+		}
+	}()
+
+	// Regular chat completion requests
+	cohereMessages := []string{
+		"Hello! How are you today?",
+		"Tell me a joke!",
+		"What's your favorite programming language?",
+	}
+
+	for i, message := range cohereMessages {
+		delay := time.Duration(100*(i+1)) * time.Millisecond
+		go func(msg string, delay time.Duration, index int) {
+			time.Sleep(delay)
+			messages := []interfaces.Message{
+				{
+					Role:    interfaces.RoleUser,
+					Content: &msg,
+				},
+			}
+			result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{
+				Model: "command-a-03-2025",
+				Input: interfaces.RequestInput{
+					ChatInput: &messages,
+				},
+				Params: nil,
+			}, ctx)
+			if err != nil {
+				fmt.Printf("Error in Cohere request %d: %v\n", index+1, err)
+			} else {
+				fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content)
+			}
+		}(message, delay, i)
+	}
+
+	// Tool calls test
+	setupCohereToolCalls(bifrost, ctx)
+}
+
+// setupCohereToolCalls tests Cohere's function calling capability
+func setupCohereToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+	cohereMessages := []string{
+		"What's the weather like in Mumbai?",
+	}
+
+	params := interfaces.ModelParameters{
+		Tools: &[]interfaces.Tool{{
+			Type: "function",
+			Function: interfaces.Function{
+				Name:        "get_weather",
+				Description: "Get the current weather in a given location",
+				Parameters: interfaces.FunctionParameters{
+					Type: "object",
+					Properties: map[string]interface{}{
+						"location": map[string]interface{}{
+							"type":        "string",
+							"description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + } + + for i, message := range cohereMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{ + Model: "command-a-03-2025", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + if err != nil { + fmt.Printf("Error in Cohere tool call request %d: %v\n", index+1, err) + } else { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) + } + }(message, delay, i) + } +} + +func TestCohere(t *testing.T) { + bifrost, err := getBifrost() + if err != nil { + t.Fatalf("Error initializing bifrost: %v", err) + return + } + + setupCohereRequests(bifrost) + + bifrost.Cleanup() +} diff --git a/tests/openai_test.go b/tests/openai_test.go index 1e82a31d5..dd73c6855 100644 --- a/tests/openai_test.go +++ b/tests/openai_test.go @@ -1,18 +1,29 @@ package tests import ( - "bifrost" - "bifrost/interfaces" + "context" "fmt" "testing" "time" + + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/interfaces" ) // setupOpenAIRequests sends multiple test requests to OpenAI func setupOpenAIRequests(bifrost *bifrost.Bifrost) { + text := "Hello world!" + ctx := context.Background() + // Text completion request go func() { - result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, "gpt-4o-mini", "Hello world!", nil) + result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + TextInput: &text, + }, + Params: nil, + }, ctx) if err != nil { fmt.Println("Error:", err) } else { @@ -20,10 +31,9 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { } }() - // Chat completion requests with different messages and delays + // Regular chat completion requests openAIMessages := []string{ "Hello! 
How are you today?", - "What's the weather like?", "Tell me a joke!", "What's your favorite programming language?", } @@ -34,15 +44,81 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { time.Sleep(delay) messages := []interfaces.Message{ { - Role: interfaces.UserRole, + Role: interfaces.RoleUser, Content: &msg, }, } - result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, "gpt-4o-mini", messages, nil) + result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: nil, + }, ctx) if err != nil { fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err) } else { - fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, result.Choices[0].Message.Content) + fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) + } + }(message, delay, i) + } + + // Tool calls test + setupOpenAIToolCalls(bifrost, ctx) +} + +// setupOpenAIToolCalls tests OpenAI's function calling capability +func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + openAIMessages := []string{ + "What's the weather like in Mumbai?", + } + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + } + + for i, message := range openAIMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + if err != nil { + fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err) + } else { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) } }(message, delay, i) } diff --git a/tests/plugin.go b/tests/plugin.go new file mode 100644 index 000000000..12395435a --- /dev/null +++ b/tests/plugin.go @@ -0,0 +1,56 @@ +package tests + +import ( + "context" + "fmt" + "time" + + "github.com/maximhq/bifrost/interfaces" + + "github.com/maximhq/maxim-go" + "github.com/maximhq/maxim-go/logging" +) + +// Define a custom type for context key to avoid collisions +type contextKey string + +const ( + traceIDKey contextKey = "traceID" +) + +type Plugin struct { + logger *logging.Logger +} + +func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) { + traceID := time.Now().Format("20060102_150405000") + + trace := plugin.logger.Trace(&logging.TraceConfig{ + Id: traceID, + Name: maxim.StrPtr("bifrost"), + }) + + trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req)) + + if ctx != nil { + // Store traceID in context + *ctx = context.WithValue(*ctx, traceIDKey, traceID) + } + + return req, nil +} + 
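+// PostHook runs after the provider responds: it reads the traceID that PreHook
+// stored in the request context and attaches the response as that trace's
+// output in the Maxim logger.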
+func (plugin *Plugin) PostHook(ctxRef *context.Context, res *interfaces.BifrostResponse) (*interfaces.BifrostResponse, error) { + // Get traceID from context + if ctxRef != nil { + ctx := *ctxRef + traceID, ok := ctx.Value(traceIDKey).(string) + if !ok { + return res, fmt.Errorf("traceID not found in context") + } + + plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res)) + } + + return res, nil +} diff --git a/tests/setup.go b/tests/setup.go index efa82723f..812a6ccf3 100644 --- a/tests/setup.go +++ b/tests/setup.go @@ -1,12 +1,16 @@ package tests import ( - "bifrost" - "bifrost/interfaces" + "fmt" "log" "os" + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/interfaces" + "github.com/joho/godotenv" + "github.com/maximhq/maxim-go" + "github.com/maximhq/maxim-go/logging" ) func loadEnv() { @@ -16,39 +20,33 @@ func loadEnv() { } } +func getPlugin() (interfaces.Plugin, error) { + loadEnv() + + mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: os.Getenv("MAXIM_API_KEY")}) + + logger, err := mx.GetLogger(&logging.LoggerConfig{Id: os.Getenv("MAXIM_LOGGER_ID")}) + if err != nil { + return nil, err + } + + plugin := &Plugin{logger} + + return plugin, nil +} + func getBifrost() (*bifrost.Bifrost, error) { loadEnv() account := BaseAccount{} - if err := account.Init( - ProviderConfig{ - "openai": { - Keys: []interfaces.Key{ - {Value: os.Getenv("OPEN_AI_API_KEY"), Weight: 1.0, Models: []string{"gpt-4o-mini"}}, - }, - ConcurrencyConfig: interfaces.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, - "anthropic": { - Keys: []interfaces.Key{ - {Value: os.Getenv("ANTHROPIC_API_KEY"), Weight: 1.0, Models: []string{"claude-3-7-sonnet-20250219", "claude-2.1"}}, - }, - ConcurrencyConfig: interfaces.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, - }, - ); err != nil { - log.Fatal("Error initializing account:", err) + plugin, err := getPlugin() + if err != nil { + fmt.Println("Error setting up the plugin:", err) return nil, err } - bifrost, err := bifrost.Init(&account) + bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}) if err != nil { - log.Fatal("Error initializing bifrost:", err) return nil, err }