diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 000000000..cd9b5cd25
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,30 @@
+{
+ "version": "0.2.0",
+ "configurations": [
+
+ {
+ "name": "Launch Bifrost",
+ "type": "go",
+ "request": "launch",
+ "mode": "debug",
+ "program": "${workspaceFolder}/bifrost.go",
+ "args": []
+ },
+ {
+ "name": "Debug Tests",
+ "type": "go",
+ "request": "launch",
+ "mode": "test",
+ "program": "${workspaceFolder}/tests",
+ "args": []
+ },
+ {
+ "name": "Attach to Process",
+ "type": "go",
+ "request": "attach",
+ "mode": "local",
+ "processId": "${command:pickProcess}"
+ }
+ ]
+ }
+
\ No newline at end of file
diff --git a/bifrost.go b/bifrost.go
index a31e6f6b4..4e8e6368c 100644
--- a/bifrost.go
+++ b/bifrost.go
@@ -1,15 +1,17 @@
package bifrost
import (
- "bifrost/interfaces"
+ "context"
"fmt"
- "log"
"math/rand"
"os"
"os/signal"
"sync"
"syscall"
"time"
+
+ "github.com/maximhq/bifrost/interfaces"
+ "github.com/maximhq/bifrost/providers"
)
type RequestType string
@@ -19,13 +21,9 @@ const (
ChatCompletionRequest RequestType = "chat_completion"
)
-// Request represents a generic request for text or chat completion
-type Request struct {
- Model string
- //* is this okay or should we do string | Message?
- Input interface{}
- Params *interfaces.ModelParameters
- Response chan *interfaces.CompletionResult
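+// ChannelMessage wraps a BifrostRequest with the response and error channels
+// used to deliver the result back to the caller.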
+type ChannelMessage struct {
+ interfaces.BifrostRequest
+ Response chan *interfaces.BifrostResponse
Err chan error
Type RequestType
}
@@ -33,31 +31,53 @@ type Request struct {
// Bifrost manages providers and maintains infinite open channels
type Bifrost struct {
account interfaces.Account
- providers []interfaces.Provider // list of processed providers
- requestQueues map[string]chan Request // provider request queues
- wg sync.WaitGroup
+ providers []interfaces.Provider // list of processed providers
+ plugins []interfaces.Plugin
+ requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues
+ waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup
+}
+
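+// createProviderFromProviderKey maps a provider key to its concrete Provider implementation.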
+func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) {
+ switch providerKey {
+ case interfaces.OpenAI:
+ return providers.NewOpenAIProvider(config), nil
+ case interfaces.Anthropic:
+ return providers.NewAnthropicProvider(config), nil
+ case interfaces.Bedrock:
+ return providers.NewBedrockProvider(config), nil
+ case interfaces.Cohere:
+ return providers.NewCohereProvider(config), nil
+ default:
+ return nil, fmt.Errorf("unsupported provider: %s", providerKey)
+ }
}
-func (bifrost *Bifrost) prepareProvider(provider interfaces.Provider) error {
- concurrency, err := bifrost.account.GetConcurrencyAndBufferSizeForProvider(provider)
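+// prepareProvider builds the provider, creates its buffered request queue, and
+// starts the configured number of worker goroutines.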
+func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) error {
+ providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
if err != nil {
- log.Fatalf("Failed to get concurrency and buffer size for provider: %v", err)
- return err
+ return fmt.Errorf("failed to get config for provider: %v", err)
}
// Check if the provider has any keys
- keys, err := bifrost.account.GetKeysForProvider(provider)
+ keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil || len(keys) == 0 {
- log.Fatalf("Failed to get keys for provider: %v", err)
- return err
+ return fmt.Errorf("failed to get keys for provider: %v", err)
}
- queue := make(chan Request, concurrency.BufferSize) // Buffered channel per provider
- bifrost.requestQueues[string(provider.GetProviderKey())] = queue
+ queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
+
+ bifrost.requestQueues[providerKey] = queue
// Start specified number of workers
- for i := 0; i < concurrency.Concurrency; i++ {
- bifrost.wg.Add(1)
+ bifrost.waitGroups[providerKey] = &sync.WaitGroup{}
+
+ provider, err := createProviderFromProviderKey(providerKey, config)
+ if err != nil {
+ return fmt.Errorf("failed to get provider for the given key: %v", err)
+ }
+
+ for i := 0; i < providerConfig.ConcurrencyAndBufferSize.Concurrency; i++ {
+ bifrost.waitGroups[providerKey].Add(1)
go bifrost.processRequests(provider, queue)
}
@@ -65,36 +85,40 @@ func (bifrost *Bifrost) prepareProvider(provider interfaces.Provider) error {
}
// Initializes infinite listening channels for each provider
-func Init(account interfaces.Account) (*Bifrost, error) {
- bifrost := &Bifrost{account: account}
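+// Usage sketch (from a consuming package; myAccount is a hypothetical
+// interfaces.Account implementation, and a nil plugin slice is valid):
+//
+//	client, err := bifrost.Init(myAccount, nil)
+//	if err != nil {
+//		// handle init failure
+//	}
+//	defer client.Shutdown()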
+func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, error) {
+ bifrost := &Bifrost{account: account, plugins: plugins}
+ bifrost.waitGroups = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup)
- providers, err := bifrost.account.GetInitiallyConfiguredProviders()
+ providerKeys, err := bifrost.account.GetInitiallyConfiguredProviders()
if err != nil {
- log.Fatalf("Failed to get initially configured providers: %v", err)
return nil, err
}
- bifrost.requestQueues = make(map[string]chan Request)
+ bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage)
// Create buffered channels for each provider and start workers
- for _, provider := range providers {
- if err := bifrost.prepareProvider(provider); err != nil {
- log.Fatalf("Failed to prepare provider: %v", err)
- return nil, err
+ for _, providerKey := range providerKeys {
+ config, err := bifrost.account.GetConfigForProvider(providerKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get config for provider: %v", err)
+ }
+
+ if err := bifrost.prepareProvider(providerKey, config); err != nil {
+ fmt.Printf("failed to prepare provider: %v", err)
}
}
return bifrost, nil
}
-func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, model string) (string, error) {
- keys, err := bifrost.account.GetKeysForProvider(provider)
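+// SelectKeyFromProviderForModel picks an API key that supports the given model,
+// chosen randomly in proportion to each key's Weight.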
+func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.SupportedModelProvider, model string) (string, error) {
+ keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil {
return "", err
}
if len(keys) == 0 {
- return "", fmt.Errorf("no keys found for provider: %v", provider.GetProviderKey())
+ return "", fmt.Errorf("no keys found for provider: %v", providerKey)
}
// filter out keys which don't support the model
@@ -113,10 +137,10 @@ func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, mod
}
// Create a new random source
- ran := rand.New(rand.NewSource(time.Now().UnixNano()))
+ randomSource := rand.New(rand.NewSource(time.Now().UnixNano()))
// Shuffle keys using the new random number generator
- ran.Shuffle(len(supportedKeys), func(i, j int) {
+ randomSource.Shuffle(len(supportedKeys), func(i, j int) {
supportedKeys[i], supportedKeys[j] = supportedKeys[j], supportedKeys[i]
})
@@ -127,7 +151,7 @@ func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, mod
}
// Generate a random number within total weight
- r := ran.Float64() * totalWeight
+ r := randomSource.Float64() * totalWeight
var cumulative float64
// Select the key based on weighted probability
@@ -142,23 +166,31 @@ func (bifrost *Bifrost) SelectFromProviderKeys(provider interfaces.Provider, mod
return supportedKeys[len(supportedKeys)-1].Value, nil
}
-func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan Request) {
- defer bifrost.wg.Done()
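+// processRequests is a worker goroutine: for each queued message it selects an
+// API key and calls the provider, until the queue is closed.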
+func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
+ defer bifrost.waitGroups[provider.GetProviderKey()].Done()
for req := range queue {
- var result *interfaces.CompletionResult
+ var result *interfaces.BifrostResponse
var err error
- key, err := bifrost.SelectFromProviderKeys(provider, req.Model)
+ key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
if err != nil {
req.Err <- err
continue
}
if req.Type == TextCompletionRequest {
- result, err = provider.TextCompletion(req.Model, key, req.Input.(string), req.Params)
+ if req.Input.TextInput == nil {
+ err = fmt.Errorf("text not provided for text completion request")
+ } else {
+ result, err = provider.TextCompletion(req.Model, key, *req.Input.TextInput, req.Params)
+ }
} else if req.Type == ChatCompletionRequest {
- result, err = provider.ChatCompletion(req.Model, key, req.Input.([]interfaces.Message), req.Params)
+ if req.Input.ChatInput == nil {
+ err = fmt.Errorf("chats not provided for chat completion request")
+ } else {
+ result, err = provider.ChatCompletion(req.Model, key, *req.Input.ChatInput, req.Params)
+ }
}
if err != nil {
@@ -171,7 +203,7 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
fmt.Println("Worker for provider", provider.GetProviderKey(), "exiting...")
}
-func (bifrost *Bifrost) GetProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) {
+func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) {
for _, provider := range bifrost.providers {
if provider.GetProviderKey() == key {
return provider, nil
@@ -181,73 +213,116 @@ func (bifrost *Bifrost) GetProviderFromProviderKey(key interfaces.SupportedModel
return nil, fmt.Errorf("no provider found for key: %s", key)
}
-func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan Request, error) {
- var queue chan Request
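+// GetProviderQueue returns the request queue for a provider, lazily preparing
+// the provider on first use.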
+func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan ChannelMessage, error) {
+ var queue chan ChannelMessage
var exists bool
- if queue, exists = bifrost.requestQueues[string(providerKey)]; !exists {
- provider, err := bifrost.GetProviderFromProviderKey(providerKey)
+ if queue, exists = bifrost.requestQueues[providerKey]; !exists {
+ config, err := bifrost.account.GetConfigForProvider(providerKey)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to get config for provider: %v", err)
}
- if err := bifrost.prepareProvider(provider); err != nil {
+ if err := bifrost.prepareProvider(providerKey, config); err != nil {
return nil, err
}
- queue = bifrost.requestQueues[string(providerKey)]
+ queue = bifrost.requestQueues[providerKey]
}
return queue, nil
}
-func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, model, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
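+// TextCompletionRequest runs plugin pre-hooks, dispatches the request onto the
+// provider's worker queue, and applies plugin post-hooks to the response in
+// reverse order.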
+func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) {
+ if req == nil {
+ return nil, fmt.Errorf("bifrost request cannot be nil")
+ }
+
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
}
- responseChan := make(chan *interfaces.CompletionResult)
+ responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan error)
- queue <- Request{
- Model: model,
- Input: text,
- Params: params,
- Response: responseChan,
- Err: errorChan,
- Type: TextCompletionRequest,
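+ // Run plugin pre-hooks in order; each may modify or replace the request.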
+ for _, plugin := range bifrost.plugins {
+ req, err = plugin.PreHook(&ctx, req)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if req == nil {
+ return nil, fmt.Errorf("bifrost request after plugin hooks cannot be nil")
+ }
+
+ queue <- ChannelMessage{
+ BifrostRequest: *req,
+ Response: responseChan,
+ Err: errorChan,
+ Type: TextCompletionRequest,
}
select {
case result := <-responseChan:
+ // Run plugins in reverse order
+ for i := len(bifrost.plugins) - 1; i >= 0; i-- {
+ result, err = bifrost.plugins[i].PostHook(&ctx, result)
+
+ if err != nil {
+ return nil, err
+ }
+ }
+
return result, nil
case err := <-errorChan:
return nil, err
}
}
-func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, model string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
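+// ChatCompletionRequest runs plugin pre-hooks, dispatches the request onto the
+// provider's worker queue, and applies plugin post-hooks to the response in
+// reverse order. Usage sketch (hypothetical model name):
+//
+//	prompt := "Hello"
+//	resp, err := client.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
+//		Model: "gpt-4o",
+//		Input: interfaces.RequestInput{ChatInput: &[]interfaces.Message{
+//			{Role: interfaces.RoleUser, Content: &prompt},
+//		}},
+//	}, context.Background())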
+func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) {
+ if req == nil {
+ return nil, fmt.Errorf("bifrost request cannot be nil")
+ }
+
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
}
- responseChan := make(chan *interfaces.CompletionResult)
+ responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan error)
- queue <- Request{
- Model: model,
- Input: messages,
- Params: params,
- Response: responseChan,
- Err: errorChan,
- Type: ChatCompletionRequest,
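+ // Run plugin pre-hooks in order; each may modify or replace the request.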
+ for _, plugin := range bifrost.plugins {
+ req, err = plugin.PreHook(&ctx, req)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if req == nil {
+ return nil, fmt.Errorf("bifrost request after pre plugin hooks cannot be nil")
+ }
+
+ queue <- ChannelMessage{
+ BifrostRequest: *req,
+ Response: responseChan,
+ Err: errorChan,
+ Type: ChatCompletionRequest,
}
- // Wait for response
select {
case result := <-responseChan:
+ // Run plugins in reverse order
+ for i := len(bifrost.plugins) - 1; i >= 0; i-- {
+ result, err = bifrost.plugins[i].PostHook(&ctx, result)
+
+ if err != nil {
+ return nil, err
+ }
+ }
+
return result, nil
case err := <-errorChan:
return nil, err
@@ -256,7 +331,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
// Shutdown gracefully stops all workers when triggered
func (bifrost *Bifrost) Shutdown() {
- fmt.Println("\n[Graceful Shutdown Initiated] Closing all request channels...")
+ fmt.Println("\n[BIFROST] Graceful Shutdown Initiated - Closing all request channels...")
// Close all provider queues to signal workers to stop
for _, queue := range bifrost.requestQueues {
@@ -264,9 +339,9 @@ func (bifrost *Bifrost) Shutdown() {
}
// Wait for all workers to exit
- bifrost.wg.Wait()
-
- fmt.Println("Bifrost has shut down gracefully.")
+ for _, waitGroup := range bifrost.waitGroups {
+ waitGroup.Wait()
+ }
}
// Cleanup handles SIGINT (Ctrl+C) to exit cleanly
diff --git a/go.mod b/go.mod
index 3f9ab680e..ef8d638aa 100644
--- a/go.mod
+++ b/go.mod
@@ -1,5 +1,31 @@
-module bifrost
+module github.com/maximhq/bifrost
-go 1.21.1
+go 1.23.0
-require github.com/joho/godotenv v1.5.1 // indirect
+toolchain go1.24.1
+
+require github.com/joho/godotenv v1.5.1
+
+require (
+ github.com/aws/aws-sdk-go-v2 v1.36.3
+ github.com/aws/aws-sdk-go-v2/config v1.29.11
+ github.com/maximhq/maxim-go v0.1.1
+ github.com/valyala/fasthttp v1.58.0
+)
+
+require (
+ github.com/andybalholm/brotli v1.1.1 // indirect
+ github.com/aws/aws-sdk-go-v2/credentials v1.17.64 // indirect
+ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 // indirect
+ github.com/aws/smithy-go v1.22.2 // indirect
+ github.com/klauspost/compress v1.17.11 // indirect
+ github.com/valyala/bytebufferpool v1.0.0 // indirect
+)
diff --git a/go.sum b/go.sum
index d61b19e1a..06db2183d 100644
--- a/go.sum
+++ b/go.sum
@@ -1,2 +1,40 @@
+github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
+github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
+github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
+github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
+github.com/aws/aws-sdk-go-v2/config v1.29.11 h1:/hkJIxaQzFQy0ebFjG5NHmAcLCrvNSuXeHnxLfeCz1Y=
+github.com/aws/aws-sdk-go-v2/config v1.29.11/go.mod h1:OFPRZVQxC4mKqy2Go6Cse/m9NOStAo6YaMvAcTMUROg=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.64 h1:NH4RAQJEXBDQDUudTqMNHdyyEVa5CvMn0tQicqv48jo=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.64/go.mod h1:tUoJfj79lzEcalHDbyNkpnZZTRg/2ayYOK/iYnRfPbo=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY=
+github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 h1:pdgODsAhGo4dvzC3JAG5Ce0PX8kWXrTZGx+jxADD+5E=
+github.com/aws/aws-sdk-go-v2/service/sso v1.25.2/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2 h1:wK8O+j2dOolmpNVY1EWIbLgxrGCHJKVPm08Hv/u80M8=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs=
+github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 h1:PZV5W8yk4OtH1JAuhV2PXwwO9v5G5Aoj+eMCn4T+1Kc=
+github.com/aws/aws-sdk-go-v2/service/sts v1.33.17/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4=
+github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
+github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
+github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
+github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
+github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M=
+github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw=
+github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
+github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
+github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbRuE=
+github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw=
+github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
diff --git a/interfaces/account.go b/interfaces/account.go
index c12f1e099..27d88d4b4 100644
--- a/interfaces/account.go
+++ b/interfaces/account.go
@@ -1,18 +1,14 @@
package interfaces
-type ConcurrencyAndBufferSize struct {
- Concurrency int `json:"concurrency"`
- BufferSize int `json:"buffer_size"`
-}
-
type Key struct {
Value string `json:"value"`
Models []string `json:"models"`
Weight float64 `json:"weight"`
}
+// TODO: consolidate into a single get-config method
type Account interface {
- GetInitiallyConfiguredProviders() ([]Provider, error)
- GetKeysForProvider(provider Provider) ([]Key, error)
- GetConcurrencyAndBufferSizeForProvider(provider Provider) (ConcurrencyAndBufferSize, error)
+ GetInitiallyConfiguredProviders() ([]SupportedModelProvider, error)
+ GetKeysForProvider(providerKey SupportedModelProvider) ([]Key, error)
+ GetConfigForProvider(providerKey SupportedModelProvider) (*ProviderConfig, error)
}
diff --git a/interfaces/bifrost.go b/interfaces/bifrost.go
new file mode 100644
index 000000000..f1d63395e
--- /dev/null
+++ b/interfaces/bifrost.go
@@ -0,0 +1,229 @@
+package interfaces
+
+// ModelChatMessageRole represents the role of a chat message
+type ModelChatMessageRole string
+
+const (
+ RoleAssistant ModelChatMessageRole = "assistant"
+ RoleUser ModelChatMessageRole = "user"
+ RoleSystem ModelChatMessageRole = "system"
+ RoleChatbot ModelChatMessageRole = "chatbot"
+ RoleTool ModelChatMessageRole = "tool"
+)
+
+type SupportedModelProvider string
+
+const (
+ OpenAI SupportedModelProvider = "openai"
+ Azure SupportedModelProvider = "azure"
+ HuggingFace SupportedModelProvider = "huggingface"
+ Anthropic SupportedModelProvider = "anthropic"
+ Google SupportedModelProvider = "google"
+ Groq SupportedModelProvider = "groq"
+ Bedrock SupportedModelProvider = "bedrock"
+ Maxim SupportedModelProvider = "maxim"
+ Cohere SupportedModelProvider = "cohere"
+ Ollama SupportedModelProvider = "ollama"
+ Lmstudio SupportedModelProvider = "lmstudio"
+)
+
+//* Request Structs
+
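+// RequestInput holds the input payload; exactly one of TextInput or ChatInput
+// should be set, matching the request type.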
+type RequestInput struct {
+ TextInput *string
+ ChatInput *[]Message
+}
+
+type BifrostRequest struct {
+ Model string
+ Input RequestInput
+ Params *ModelParameters
+}
+
+// ModelParameters represents the parameters for model requests
+type ModelParameters struct {
+ ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
+ Tools *[]Tool `json:"tools,omitempty"`
+
+ // Common model parameters
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
+ TopK *int `json:"top_k,omitempty"`
+ MaxTokens *int `json:"max_tokens,omitempty"`
+ StopSequences *[]string `json:"stop_sequences,omitempty"`
+ PresencePenalty *float64 `json:"presence_penalty,omitempty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+ ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
+
+ // Dynamic parameters
+ ExtraParams map[string]interface{} `json:"-"`
+}
+
+type FunctionParameters struct {
+ Type string `json:"type"`
+ Required []string `json:"required"`
+ Properties map[string]interface{} `json:"properties"`
+}
+
+// Function represents a function definition for tool calls
+type Function struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters FunctionParameters `json:"parameters"`
+}
+
+// Tool represents a tool that can be used with the model
+type Tool struct {
+ ID *string `json:"id,omitempty"`
+ Type string `json:"type"`
+ Function Function `json:"function"`
+}
+
+// combined tool choices for all providers
+type ToolChoiceType string
+
+const (
+ ToolChoiceNone ToolChoiceType = "none"
+ ToolChoiceAuto ToolChoiceType = "auto"
+ ToolChoiceAny ToolChoiceType = "any"
+ ToolChoiceTool ToolChoiceType = "tool"
+ ToolChoiceRequired ToolChoiceType = "required"
+)
+
+type ToolChoiceFunction struct {
+ Name string `json:"name"`
+}
+
+type ToolChoice struct {
+ Type ToolChoiceType `json:"type"`
+ Function ToolChoiceFunction `json:"function"`
+}
+
+type Message struct {
+ //* strict check for roles
+ Role ModelChatMessageRole `json:"role"`
+ //* need to make sure either content or imagecontent is provided
+ Content *string `json:"content,omitempty"`
+ ImageContent *ImageContent `json:"image_content,omitempty"`
+ ToolCalls *[]Tool `json:"tool_calls,omitempty"`
+}
+
+type ImageContent struct {
+ Type string `json:"type"`
+ URL string `json:"url"`
+ MediaType string `json:"media_type"`
+}
+
+//* Response Structs
+
+// LLMUsage represents token usage information
+type LLMUsage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ TokenDetails *TokenDetails `json:"prompt_tokens_details,omitempty"`
+ CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
+}
+
+type TokenDetails struct {
+ CachedTokens int `json:"cached_tokens,omitempty"`
+ AudioTokens int `json:"audio_tokens,omitempty"`
+}
+
+type CompletionTokensDetails struct {
+ ReasoningTokens int `json:"reasoning_tokens,omitempty"`
+ AudioTokens int `json:"audio_tokens,omitempty"`
+ AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
+ RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
+}
+
+type BilledLLMUsage struct {
+ PromptTokens *float64 `json:"prompt_tokens,omitempty"`
+ CompletionTokens *float64 `json:"completion_tokens,omitempty"`
+ SearchUnits *float64 `json:"search_units,omitempty"`
+ Classifications *float64 `json:"classifications,omitempty"`
+}
+
+type LogProb struct {
+ Bytes []int `json:"bytes,omitempty"`
+ LogProb float64 `json:"logprob"`
+ Token string `json:"token"`
+}
+
+type ContentLogProb struct {
+ Bytes []int `json:"bytes"`
+ LogProb float64 `json:"logprob"`
+ Token string `json:"token"`
+ TopLogProbs []LogProb `json:"top_logprobs"`
+}
+
+type LogProbs struct {
+ Content []ContentLogProb `json:"content"`
+ Refusal []LogProb `json:"refusal"`
+}
+
+type FunctionCall struct {
+ Name *string `json:"name"`
+ Arguments string `json:"arguments"` // stringified JSON as returned by OpenAI; may not always be valid JSON
+}
+
+// ToolCall represents a tool call in a message
+type ToolCall struct {
+ Type *string `json:"type,omitempty"`
+ ID *string `json:"id,omitempty"`
+ Function FunctionCall `json:"function"`
+}
+
+type Citation struct {
+ StartIndex int `json:"start_index"`
+ EndIndex int `json:"end_index"`
+ Title string `json:"title"`
+ URL *string `json:"url,omitempty"`
+ Sources *interface{} `json:"sources,omitempty"`
+ Type *string `json:"type,omitempty"`
+}
+
+type Annotation struct {
+ Type string `json:"type"`
+ Citation Citation `json:"url_citation"`
+}
+
+// BifrostResponseChoiceMessage represents a choice in the completion response
+type BifrostResponseChoiceMessage struct {
+ Role ModelChatMessageRole `json:"role"`
+ Content *string `json:"content,omitempty"`
+ Refusal *string `json:"refusal,omitempty"`
+ Annotations []Annotation `json:"annotations,omitempty"`
+ ToolCalls *[]ToolCall `json:"tool_calls,omitempty"`
+}
+
+// BifrostResponseChoice represents a choice in the completion result
+type BifrostResponseChoice struct {
+ Index int `json:"index"`
+ Message BifrostResponseChoiceMessage `json:"message"`
+ FinishReason *string `json:"finish_reason,omitempty"`
+ StopString *string `json:"stop,omitempty"`
+ LogProbs *LogProbs `json:"log_probs,omitempty"`
+}
+
+type BifrostResponseExtraFields struct {
+ Provider SupportedModelProvider `json:"provider"`
+ Params ModelParameters `json:"model_params"`
+ Latency *float64 `json:"latency,omitempty"`
+ ChatHistory *[]BifrostResponseChoiceMessage `json:"chat_history,omitempty"`
+ BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"`
+ RawResponse interface{} `json:"raw_response"`
+}
+
+// BifrostResponse represents the complete result from a model completion
+type BifrostResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"` // text.completion or chat.completion
+ Choices []BifrostResponseChoice `json:"choices"`
+ Model string `json:"model"`
+ Created int `json:"created"` // The Unix timestamp (in seconds).
+ ServiceTier *string `json:"service_tier,omitempty"`
+ SystemFingerprint *string `json:"system_fingerprint,omitempty"`
+ Usage LLMUsage `json:"usage"`
+ ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
+}
diff --git a/interfaces/plugin.go b/interfaces/plugin.go
new file mode 100644
index 000000000..380379f36
--- /dev/null
+++ b/interfaces/plugin.go
@@ -0,0 +1,8 @@
+package interfaces
+
+import "context"
+
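+// Plugin intercepts requests and responses: PreHook runs before a request is
+// queued, and PostHook runs on the response (post-hooks run in reverse order).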
+type Plugin interface {
+ PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error)
+ PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error)
+}
diff --git a/interfaces/provider.go b/interfaces/provider.go
index b555e1061..cb07a7596 100644
--- a/interfaces/provider.go
+++ b/interfaces/provider.go
@@ -1,213 +1,33 @@
package interfaces
-// LLMUsage represents token usage information
-type LLMUsage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
- Latency float64 `json:"latency,omitempty"`
-}
-
-// LLMInteractionCost represents cost information for LLM interactions
-type LLMInteractionCost struct {
- Input float64 `json:"input"`
- Output float64 `json:"output"`
- Total float64 `json:"total"`
-}
-
-// Function represents a function definition for tool calls
-type Function struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Parameters interface{} `json:"parameters"`
-}
-
-// Tool represents a tool that can be used with the model
-type Tool struct {
- Type string `json:"type"`
- Function Function `json:"function"`
-}
-
-// ModelParameters represents the parameters for model requests
-type ModelParameters struct {
- TestRunEntryID *string `json:"testRunEntryId,omitempty"`
- PromptTools *[]string `json:"promptTools,omitempty"`
- ToolChoice *string `json:"toolChoice,omitempty"`
- Tools *[]Tool `json:"tools,omitempty"`
- FunctionCall *string `json:"functionCall,omitempty"`
- Functions *[]Function `json:"functions,omitempty"`
- // Dynamic parameters
- ExtraParams map[string]interface{} `json:"-"`
-}
-
-// RequestOptions represents options for model requests
-type RequestOptions struct {
- UseCache bool `json:"useCache,omitempty"`
- WaitForModel bool `json:"waitForModel,omitempty"`
- CompletionType string `json:"CompletionType,omitempty"`
-}
-
-// FunctionCall represents a function call in a tool call
-type FunctionCall struct {
- Name string `json:"name"`
- Arguments string `json:"arguments"`
-}
-
-// ToolCall represents a tool call in a message
-type ToolCall struct {
- Type string `json:"type"`
- ID string `json:"id"`
- Function FunctionCall `json:"function"`
-}
-
-// ModelChatMessageRole represents the role of a chat message
-type ModelChatMessageRole string
-
-const (
- RoleAssistant ModelChatMessageRole = "assistant"
- RoleUser ModelChatMessageRole = "user"
- RoleSystem ModelChatMessageRole = "system"
- RoleModel ModelChatMessageRole = "model"
- RoleChatbot ModelChatMessageRole = "chatbot"
- RoleTool ModelChatMessageRole = "tool"
-)
-
-// CompletionResponseChoice represents a choice in the completion response
-type CompletionResponseChoice struct {
- Role ModelChatMessageRole `json:"role"`
- Content string `json:"content"`
- FunctionCall *FunctionCall `json:"function_call,omitempty"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
-}
-
-// CompletionResultChoice represents a choice in the completion result
-type CompletionResultChoice struct {
- Index int `json:"index"`
- Message CompletionResponseChoice `json:"message"`
- FinishReason string `json:"finish_reason,omitempty"`
- LogProbs interface{} `json:"logprobs,omitempty"`
-}
-
-// ToolResult represents the result of a tool call
-type ToolResult struct {
- Role ModelChatMessageRole `json:"role"`
- Content string `json:"content"`
- ToolCallID string `json:"tool_call_id"`
-}
+// TODO: support third-party providers
-// ToolCallResult represents a single tool call result
-type ToolCallResult struct {
- Name string `json:"name"`
- Result interface{} `json:"result"`
- Type string `json:"type"`
- ID string `json:"id"`
+type NetworkConfig struct {
+ DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
}
-// ToolCallResults represents a collection of tool call results
-type ToolCallResults struct {
- Version int `json:"version"`
- Results []ToolCallResult `json:"results"`
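+// MetaConfig carries provider-specific settings; the current fields are AWS
+// credentials and Bedrock inference profiles.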
+type MetaConfig struct {
+ SecretAccessKey *string `json:"secret_access_key,omitempty"`
+ Region *string `json:"region,omitempty"`
+ SessionToken *string `json:"session_token,omitempty"`
+ ARN *string `json:"arn,omitempty"`
+ InferenceProfiles map[string]string `json:"inference_profiles,omitempty"`
}
-// CompletionResult represents the complete result from a model completion
-type CompletionResult struct {
- Error *struct {
- Code string `json:"code"`
- Message string `json:"message"`
- Type string `json:"type"`
- } `json:"error,omitempty"`
- ID string `json:"id"`
- Choices []CompletionResultChoice `json:"choices"`
- ToolCallResult interface{} `json:"tool_call_result,omitempty"`
- ToolCallResults *ToolCallResults `json:"toolCallResults,omitempty"`
- Provider SupportedModelProvider `json:"provider,omitempty"`
- Usage LLMUsage `json:"usage"`
- Cost *LLMInteractionCost `json:"cost,omitempty"`
- Model string `json:"model,omitempty"`
- Created string `json:"created,omitempty"`
- ModelParams interface{} `json:"modelParams,omitempty"`
- Trace *struct {
- Input interface{} `json:"input"`
- Output interface{} `json:"output,omitempty"`
- } `json:"trace,omitempty"`
- RetrievedContext interface{} `json:"retrievedContext,omitempty"`
- VariableBoundRetrievals map[string]interface{} `json:"variableBoundRetrievals,omitempty"`
-}
-
-type SupportedModelProvider string
-
-const (
- OpenAI SupportedModelProvider = "openai"
- Azure SupportedModelProvider = "azure"
- HuggingFace SupportedModelProvider = "huggingface"
- Anthropic SupportedModelProvider = "anthropic"
- Google SupportedModelProvider = "google"
- Groq SupportedModelProvider = "groq"
- Bedrock SupportedModelProvider = "bedrock"
- Maxim SupportedModelProvider = "maxim"
- Cohere SupportedModelProvider = "cohere"
- Ollama SupportedModelProvider = "ollama"
- Lmstudio SupportedModelProvider = "lmstudio"
-)
-
-type Role string
-
-const (
- UserRole Role = "user"
- AssistantRole Role = "assistant"
- SystemRole Role = "system"
-)
-
-type Message struct {
- //* strict check for roles
- Role Role `json:"role"`
- //* need to make sure either content or imagecontent is provided
- Content *string `json:"content"`
- ImageContent *ImageContent `json:"imageContent"`
+type ConcurrencyAndBufferSize struct {
+ Concurrency int `json:"concurrency"`
+ BufferSize int `json:"buffer_size"`
}
-type ImageContent struct {
- Type string `json:"type"`
- ImageURL struct {
- URL string `json:"url"`
- } `json:"image_url"`
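+// ProviderConfig groups a provider's network settings, optional
+// provider-specific meta config, and worker-pool sizing.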
+type ProviderConfig struct {
+ NetworkConfig NetworkConfig `json:"network_config"`
+ MetaConfig *MetaConfig `json:"meta_config,omitempty"`
+ ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"`
}
-// type Content struct {
-// Content *string `json:"content"`
-// ImageContent *ImageContent `json:"imageContent"`
-// }
-
-// func (content *Content) MarshalJSON() ([]byte, error) {
-// if content.Content != nil {
-// return []byte(*content.Content), nil
-// } else if content.ImageContent != nil {
-// return json.Marshal(content.ImageContent)
-// }
-
-// return nil, fmt.Errorf("invalid content")
-// }
-
-// func (content *Content) UnmarshalJSON(val []byte) error {
-// var s any
-// json.Unmarshal(val, &s)
-
-// switch s := s.(type) {
-// case string:
-// content.Content = &s
-// case ImageContent:
-// content.ImageContent = &s
-
-// default:
-// return fmt.Errorf("invalid stop")
-// }
-
-// return nil
-// }
-
// Provider defines the interface for AI model providers
type Provider interface {
GetProviderKey() SupportedModelProvider
- TextCompletion(model, key, text string, params *ModelParameters) (*CompletionResult, error)
- ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*CompletionResult, error)
+ TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, error)
+ ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, error)
}
diff --git a/providers/anthropic.go b/providers/anthropic.go
index c9b874b8c..5920ade65 100644
--- a/providers/anthropic.go
+++ b/providers/anthropic.go
@@ -1,24 +1,67 @@
package providers
import (
- "bifrost/interfaces"
- "bytes"
"encoding/json"
"fmt"
- "io"
- "net/http"
"time"
+
+ "github.com/maximhq/bifrost/interfaces"
+ "github.com/valyala/fasthttp"
+
+ "github.com/maximhq/maxim-go"
)
+type AnthropicToolChoice struct {
+ Type interfaces.ToolChoiceType `json:"type"`
+ Name *string `json:"name"`
+ DisableParallelToolUse *bool `json:"disable_parallel_tool_use"`
+}
+
+type AnthropicTextResponse struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Completion string `json:"completion"`
+ Model string `json:"model"`
+ Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ } `json:"usage"`
+}
+
+type AnthropicChatResponse struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Role string `json:"role"`
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+ Thinking string `json:"thinking,omitempty"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
+ } `json:"content"`
+ Model string `json:"model"`
+ StopReason string `json:"stop_reason,omitempty"`
+ StopSequence *string `json:"stop_sequence,omitempty"`
+ Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ } `json:"usage"`
+}
+
// AnthropicProvider implements the Provider interface for Anthropic's Claude API
type AnthropicProvider struct {
- client *http.Client
+ client *fasthttp.Client
}
// NewAnthropicProvider creates a new AnthropicProvider instance
-func NewAnthropicProvider() *AnthropicProvider {
+func NewAnthropicProvider(config *interfaces.ProviderConfig) *AnthropicProvider {
return &AnthropicProvider{
- client: &http.Client{Timeout: 30 * time.Second},
+ client: &fasthttp.Client{
+ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+ WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+ MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize,
+ },
}
}
@@ -26,9 +69,54 @@ func (provider *AnthropicProvider) GetProviderKey() interfaces.SupportedModelPro
return interfaces.Anthropic
}
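+// PrepareTextCompletionParams renames max_tokens to max_tokens_to_sample, as
+// Anthropic's legacy text completion endpoint expects.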
+func (provider *AnthropicProvider) PrepareTextCompletionParams(params map[string]interface{}) map[string]interface{} {
+ // Check if there is a key entry for max_tokens
+ if maxTokens, exists := params["max_tokens"]; exists {
+ // Check if max_tokens_to_sample is already present
+ if _, exists := params["max_tokens_to_sample"]; !exists {
+ // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample
+ params["max_tokens_to_sample"] = maxTokens
+ }
+ delete(params, "max_tokens")
+ }
+ return params
+}
+
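+// PrepareToolChoices converts a generic ToolChoice into Anthropic's tool_choice
+// shape, folding parallel_tool_calls into DisableParallelToolUse.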
+func (provider *AnthropicProvider) PrepareToolChoices(params map[string]interface{}) map[string]interface{} {
+ toolChoice, exists := params["tool_choice"]
+ if !exists {
+ return params
+ }
+
+ switch tc := toolChoice.(type) {
+ case interfaces.ToolChoice:
+ anthropicToolChoice := AnthropicToolChoice{
+ Type: tc.Type,
+ Name: &tc.Function.Name,
+ }
+
+ parallelToolCalls, exists := params["parallel_tool_calls"]
+ if !exists {
+ return params
+ }
+
+ switch parallelTC := parallelToolCalls.(type) {
+ case bool:
+ disableParallel := !parallelTC
+ anthropicToolChoice.DisableParallelToolUse = &disableParallel
+
+ delete(params, "parallel_tool_calls")
+ }
+
+ params["tool_choice"] = anthropicToolChoice
+ }
+
+ return params
+}
+
// TextCompletion implements text completion using Anthropic's API
-func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- preparedParams := PrepareParams(params)
+func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params))
// Merge additional parameters
requestBody := MergeConfig(map[string]interface{}{
@@ -43,88 +131,72 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
}
// Create the request with the JSON body
- req, err := http.NewRequest("POST", "https://api.anthropic.com/v1/complete", bytes.NewBuffer(jsonData))
- if err != nil {
- return nil, fmt.Errorf("error creating request: %v", err)
- }
-
- // Set headers
- req.Header.Set("Content-Type", "application/json")
+ req := fasthttp.AcquireRequest()
+ resp := fasthttp.AcquireResponse()
+ defer fasthttp.ReleaseRequest(req)
+ defer fasthttp.ReleaseResponse(resp)
+
+ req.SetRequestURI("https://api.anthropic.com/v1/complete")
+ req.Header.SetMethod("POST")
+ req.Header.SetContentType("application/json")
req.Header.Set("x-api-key", key)
req.Header.Set("anthropic-version", "2023-06-01")
+ req.SetBody(jsonData)
// Send the request
- resp, err := provider.client.Do(req)
- if err != nil {
+ if err := provider.client.Do(req, resp); err != nil {
return nil, fmt.Errorf("error sending request: %v", err)
}
- defer resp.Body.Close()
- // Read the response body
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("error reading response: %v", err)
+ // Handle error response
+ if resp.StatusCode() != fasthttp.StatusOK {
+ return nil, fmt.Errorf("anthropic error: %s", resp.Body())
}
- // Check for error response
- if resp.StatusCode != http.StatusOK {
- var errorResp struct {
- Type string `json:"type"`
- Error struct {
- Type string `json:"type"`
- Message string `json:"message"`
- } `json:"error"`
- }
- if err := json.Unmarshal(body, &errorResp); err != nil {
- return nil, fmt.Errorf("error response: %s", string(body))
- }
- return nil, fmt.Errorf("anthropic error: %s", errorResp.Error.Message)
- }
+ // Read the response body
+ body := resp.Body()
// Parse the response
- var result struct {
- ID string `json:"id"`
- Type string `json:"type"`
- Completion string `json:"completion"`
- Model string `json:"model"`
- Usage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- } `json:"usage"`
+ var response AnthropicTextResponse
+ if err := json.Unmarshal(body, &response); err != nil {
+ return nil, fmt.Errorf("error parsing response: %v", err)
}
- if err := json.Unmarshal(body, &result); err != nil {
- return nil, fmt.Errorf("error parsing response: %v", err)
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("error parsing raw response: %v", err)
}
// Create the completion result
- completionResult := &interfaces.CompletionResult{
- ID: result.ID,
- Choices: []interfaces.CompletionResultChoice{
+ completionResult := &interfaces.BifrostResponse{
+ ID: response.ID,
+ Choices: []interfaces.BifrostResponseChoice{
{
Index: 0,
- Message: interfaces.CompletionResponseChoice{
+ Message: interfaces.BifrostResponseChoiceMessage{
Role: interfaces.RoleAssistant,
- Content: result.Completion,
+ Content: &response.Completion,
},
},
},
Usage: interfaces.LLMUsage{
- PromptTokens: result.Usage.InputTokens,
- CompletionTokens: result.Usage.OutputTokens,
- TotalTokens: result.Usage.InputTokens + result.Usage.OutputTokens,
+ PromptTokens: response.Usage.InputTokens,
+ CompletionTokens: response.Usage.OutputTokens,
+ TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
+ },
+ Model: response.Model,
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Provider: interfaces.Anthropic,
+ RawResponse: rawResponse,
},
- Model: result.Model,
- Provider: interfaces.Anthropic,
}
return completionResult, nil
}
// ChatCompletion implements chat completion using Anthropic's API
-func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- startTime := time.Now()
-
+func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
// Format messages for Anthropic API
var formattedMessages []map[string]interface{}
for _, msg := range messages {
@@ -136,137 +208,125 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
preparedParams := PrepareParams(params)
+ // Transform tools if present
+ if params != nil && params.Tools != nil && len(*params.Tools) > 0 {
+ var tools []map[string]interface{}
+ for _, tool := range *params.Tools {
+ tools = append(tools, map[string]interface{}{
+ "name": tool.Function.Name,
+ "description": tool.Function.Description,
+ "input_schema": tool.Function.Parameters,
+ })
+ }
+
+ preparedParams["tools"] = tools
+ }
+
// Merge additional parameters
requestBody := MergeConfig(map[string]interface{}{
"model": model,
"messages": formattedMessages,
}, preparedParams)
- jsonData, err := json.Marshal(requestBody)
+ jsonBody, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %v", err)
}
// Create request
- req, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages", bytes.NewBuffer(jsonData))
- if err != nil {
- return nil, fmt.Errorf("error creating request: %v", err)
- }
-
- // Set headers
- req.Header.Set("Content-Type", "application/json")
+ req := fasthttp.AcquireRequest()
+ resp := fasthttp.AcquireResponse()
+ defer fasthttp.ReleaseRequest(req)
+ defer fasthttp.ReleaseResponse(resp)
+
+ req.SetRequestURI("https://api.anthropic.com/v1/messages")
+ req.Header.SetMethod("POST")
+ req.Header.SetContentType("application/json")
req.Header.Set("x-api-key", key)
req.Header.Set("anthropic-version", "2023-06-01")
+ req.SetBody(jsonBody)
- // Send request
- resp, err := provider.client.Do(req)
- if err != nil {
- return nil, fmt.Errorf("error sending request: %v", err)
- }
- defer resp.Body.Close()
-
- // Read response body
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("error reading response: %v", err)
+ // Make request
+ if err := provider.client.Do(req, resp); err != nil {
+ return nil, fmt.Errorf("error making request: %v", err)
}
- // Check for non-200 status codes
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("API error: %s", string(body))
+ // Handle error response
+ if resp.StatusCode() != fasthttp.StatusOK {
+ return nil, fmt.Errorf("anthropic error: %s", resp.Body())
}
- // Calculate latency
- latency := time.Since(startTime).Seconds()
-
- // Decode response
- var anthropicResponse struct {
- ID string `json:"id"`
- Type string `json:"type"`
- Role string `json:"role"`
- Content []struct {
- Type string `json:"type"`
- Text string `json:"text,omitempty"`
- Thinking string `json:"thinking,omitempty"`
- ToolUse *struct {
- ID string `json:"id"`
- Name string `json:"name"`
- Input map[string]interface{} `json:"input"`
- } `json:"tool_use,omitempty"`
- } `json:"content"`
- Model string `json:"model"`
- StopReason string `json:"stop_reason,omitempty"`
- StopSequence *string `json:"stop_sequence,omitempty"`
- Usage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- } `json:"usage"`
+ // Decode structured response
+ var response AnthropicChatResponse
+ if err := json.Unmarshal(resp.Body(), &response); err != nil {
+ return nil, fmt.Errorf("error decoding structured response: %v", err)
}
- if err := json.Unmarshal(body, &anthropicResponse); err != nil {
- return nil, fmt.Errorf("error decoding response: %v", err)
+ // Decode raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(resp.Body(), &rawResponse); err != nil {
+ return nil, fmt.Errorf("error decoding raw response: %v", err)
}
- // Process the response into our CompletionResult format
- var content string
- var toolCalls []interfaces.ToolCall
- var finishReason string
+ // Process the response into our BifrostResponse format
+ var choices []interfaces.BifrostResponseChoice
// Process content and tool calls
- for _, c := range anthropicResponse.Content {
+ for i, c := range response.Content {
+ var content string
+ var toolCalls []interfaces.ToolCall
+
switch c.Type {
case "thinking":
- if content == "" {
- content = fmt.Sprintf("\n%s\n\n\n", c.Thinking)
- }
+ content = c.Thinking
case "text":
- content += c.Text
+ content = c.Text
case "tool_use":
- if c.ToolUse != nil {
- toolCalls = append(toolCalls, interfaces.ToolCall{
- Type: "function",
- ID: c.ToolUse.ID,
- Function: interfaces.FunctionCall{
- Name: c.ToolUse.Name,
- Arguments: string(must(json.Marshal(c.ToolUse.Input))),
- },
- })
- finishReason = "tool_calls"
+ function := interfaces.FunctionCall{
+ Name: &c.Name,
}
+
+ args, err := json.Marshal(c.Input)
+ if err != nil {
+ function.Arguments = fmt.Sprintf("%v", c.Input)
+ } else {
+ function.Arguments = string(args)
+ }
+
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ Type: maxim.StrPtr("function"),
+ ID: &c.ID,
+ Function: function,
+ })
}
+
+ choices = append(choices, interfaces.BifrostResponseChoice{
+ Index: i,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: interfaces.RoleAssistant,
+ Content: &content,
+ ToolCalls: &toolCalls,
+ },
+ FinishReason: &response.StopReason,
+ StopString: response.StopSequence,
+ })
}
// Create the completion result
- result := &interfaces.CompletionResult{
- ID: anthropicResponse.ID,
- Choices: []interfaces.CompletionResultChoice{
- {
- Index: 0,
- Message: interfaces.CompletionResponseChoice{
- Role: interfaces.RoleAssistant,
- Content: content,
- ToolCalls: toolCalls,
- },
- FinishReason: finishReason,
- },
- },
+ result := &interfaces.BifrostResponse{
+ ID: response.ID,
+ Choices: choices,
Usage: interfaces.LLMUsage{
- PromptTokens: anthropicResponse.Usage.InputTokens,
- CompletionTokens: anthropicResponse.Usage.OutputTokens,
- TotalTokens: anthropicResponse.Usage.InputTokens + anthropicResponse.Usage.OutputTokens,
- Latency: latency,
+ PromptTokens: response.Usage.InputTokens,
+ CompletionTokens: response.Usage.OutputTokens,
+ TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
+ },
+ Model: response.Model,
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Provider: interfaces.Anthropic,
+ RawResponse: rawResponse,
},
- Model: anthropicResponse.Model,
- Provider: interfaces.Anthropic,
}
return result, nil
}
-
-// Helper function to handle JSON marshaling errors
-func must[T any](v T, err error) T {
- if err != nil {
- panic(err)
- }
- return v
-}
diff --git a/providers/bedrock.go b/providers/bedrock.go
new file mode 100644
index 000000000..f49a59f4f
--- /dev/null
+++ b/providers/bedrock.go
@@ -0,0 +1,526 @@
+package providers
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/maximhq/bifrost/interfaces"
+)
+
+type BedrockAnthropicTextResponse struct {
+ Completion string `json:"completion"`
+ StopReason string `json:"stop_reason"`
+ Stop string `json:"stop"`
+}
+
+type BedrockMistralTextResponse struct {
+ Outputs []struct {
+ Text string `json:"text"`
+ StopReason string `json:"stop_reason"`
+ } `json:"outputs"`
+}
+
+type BedrockChatResponse struct {
+ Metrics struct {
+ Latency int `json:"latencyMs"`
+ } `json:"metrics"`
+ Output struct {
+ Message struct {
+ Content []struct {
+ Text string `json:"text"`
+ } `json:"content"`
+ Role string `json:"role"`
+ } `json:"message"`
+ } `json:"output"`
+ StopReason string `json:"stopReason"`
+ Usage struct {
+ InputTokens int `json:"inputTokens"`
+ OutputTokens int `json:"outputTokens"`
+ TotalTokens int `json:"totalTokens"`
+ } `json:"usage"`
+}
+
+type BedrockAnthropicSystemMessage struct {
+ Text string `json:"text"`
+}
+
+type BedrockAnthropicTextMessage struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+}
+
+type BedrockMistralContent struct {
+ Text string `json:"text"`
+}
+
+type BedrockMistralChatMessage struct {
+ Role interfaces.ModelChatMessageRole `json:"role"`
+ Content []BedrockMistralContent `json:"content"`
+ ToolCalls *[]BedrockMistralToolCall `json:"tool_calls,omitempty"`
+ ToolCallID *string `json:"tool_call_id,omitempty"`
+}
+
+type BedrockAnthropicImageMessage struct {
+ Type string `json:"type"`
+ Source BedrockAnthropicImageSource `json:"source"`
+}
+
+type BedrockAnthropicImageSource struct {
+ Type string `json:"type"`
+ MediaType string `json:"media_type"`
+ Data string `json:"data"`
+}
+
+type BedrockMistralToolCall struct {
+ ID string `json:"id"`
+ Function interfaces.Function `json:"function"`
+}
+
+type BedrockAnthropicToolCall struct {
+ ToolSpec BedrockAnthropicToolSpec `json:"toolSpec"`
+}
+
+type BedrockAnthropicToolSpec struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema struct {
+ Json interface{} `json:"json"`
+ } `json:"inputSchema"`
+}
+
+type BedrockProvider struct {
+ client *http.Client
+ meta *interfaces.MetaConfig
+}
+
+func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider {
+ return &BedrockProvider{
+ client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)},
+ meta: config.MetaConfig,
+ }
+}
+
+func (provider *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider {
+ return interfaces.Bedrock
+}
+
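+// PrepareReq builds a POST request to the Bedrock runtime endpoint for the
+// given model path and signs it via SignAWSRequest.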
+func (provider *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) {
+ if provider.meta == nil {
+ return nil, errors.New("meta config for bedrock is not provided")
+ }
+
+ region := "us-east-1"
+ if provider.meta.Region != nil {
+ region = *provider.meta.Region
+ }
+
+ // Create the request with the JSON body
+ req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, fmt.Errorf("error creating request: %v", err)
+ }
+
+ if err := SignAWSRequest(req, accessKey, *provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil {
+ return nil, err
+ }
+
+ return req, nil
+}
+
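+// GetTextCompletionResult parses a raw Bedrock response body into a
+// BifrostResponse, branching on the model family.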
+func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.BifrostResponse, error) {
+ switch model {
+ case "anthropic.claude-instant-v1:2":
+ fallthrough
+ case "anthropic.claude-v2":
+ fallthrough
+ case "anthropic.claude-v2:1":
+ var response BedrockAnthropicTextResponse
+ if err := json.Unmarshal(result, &response); err != nil {
+ return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+ }
+
+ return &interfaces.BifrostResponse{
+ Choices: []interfaces.BifrostResponseChoice{
+ {
+ Index: 0,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: interfaces.RoleAssistant,
+ Content: &response.Completion,
+ },
+ FinishReason: &response.StopReason,
+ StopString: &response.Stop,
+ },
+ },
+ Model: model,
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Provider: interfaces.Bedrock,
+ },
+ }, nil
+
+ case "mistral.mixtral-8x7b-instruct-v0:1":
+ fallthrough
+ case "mistral.mistral-7b-instruct-v0:2":
+ fallthrough
+ case "mistral.mistral-large-2402-v1:0":
+ fallthrough
+ case "mistral.mistral-large-2407-v1:0":
+ fallthrough
+ case "mistral.mistral-small-2402-v1:0":
+ var response BedrockMistralTextResponse
+ if err := json.Unmarshal(result, &response); err != nil {
+ return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+ }
+
+ var choices []interfaces.BifrostResponseChoice
+ for i, output := range response.Outputs {
+ choices = append(choices, interfaces.BifrostResponseChoice{
+ Index: i,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: interfaces.RoleAssistant,
+ Content: &output.Text,
+ },
+ FinishReason: &output.StopReason,
+ })
+ }
+
+ return &interfaces.BifrostResponse{
+ Choices: choices,
+ Model: model,
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Provider: interfaces.Bedrock,
+ },
+ }, nil
+ }
+
+ return nil, fmt.Errorf("invalid model choice: %s", model)
+}
+
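+// PrepareChatCompletionMessages converts generic messages into the
+// request-body format the given Bedrock model expects.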
+func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) {
+ switch model {
+ case "anthropic.claude-instant-v1:2":
+ fallthrough
+ case "anthropic.claude-v2":
+ fallthrough
+ case "anthropic.claude-v2:1":
+ fallthrough
+ case "anthropic.claude-3-sonnet-20240229-v1:0":
+ fallthrough
+ case "anthropic.claude-3-5-sonnet-20240620-v1:0":
+ fallthrough
+ case "anthropic.claude-3-5-sonnet-20241022-v2:0":
+ fallthrough
+ case "anthropic.claude-3-5-haiku-20241022-v1:0":
+ fallthrough
+ case "anthropic.claude-3-opus-20240229-v1:0":
+ fallthrough
+ case "anthropic.claude-3-7-sonnet-20250219-v1:0":
+ // Collect system messages separately; Bedrock's Anthropic API takes them as a top-level "system" field
+ var systemMessages []BedrockAnthropicSystemMessage
+ for _, msg := range messages {
+ if msg.Role == interfaces.RoleSystem && msg.Content != nil {
+ // TODO: handle image inputs in system messages
+ systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{
+ Text: *msg.Content,
+ })
+ }
+ }
+
+ // Format messages for Bedrock API
+ var bedrockMessages []map[string]interface{}
+ for _, msg := range messages {
+ if msg.Role != interfaces.RoleSystem {
+ var content any
+ if msg.Content != nil {
+ content = BedrockAnthropicTextMessage{
+ Type: "text",
+ Text: *msg.Content,
+ }
+ } else if msg.ImageContent != nil {
+ content = BedrockAnthropicImageMessage{
+ Type: "image",
+ Source: BedrockAnthropicImageSource{
+ Type: msg.ImageContent.Type,
+ MediaType: msg.ImageContent.MediaType,
+ Data: msg.ImageContent.URL,
+ },
+ }
+ }
+
+ bedrockMessages = append(bedrockMessages, map[string]interface{}{
+ "role": msg.Role,
+ "content": []interface{}{content},
+ })
+ }
+ }
+
+ body := map[string]interface{}{
+ "messages": bedrockMessages,
+ }
+
+ if len(systemMessages) > 0 {
+ var systemTexts []string
+ for _, message := range systemMessages {
+ systemTexts = append(systemTexts, message.Text)
+ }
+
+ body["system"] = strings.Join(systemTexts, " ")
+ }
+
+ return body, nil
+
+ case "mistral.mistral-large-2402-v1:0":
+ fallthrough
+ case "mistral.mistral-large-2407-v1:0":
+ var bedrockMessages []BedrockMistralChatMessage
+ for _, msg := range messages {
+ var filteredToolCalls []BedrockMistralToolCall
+ if msg.ToolCalls != nil {
+ for _, toolCall := range *msg.ToolCalls {
+ filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{
+ ID: *toolCall.ID,
+ Function: toolCall.Function,
+ })
+ }
+ }
+
+ message := BedrockMistralChatMessage{
+ Role: msg.Role,
+ Content: []BedrockMistralContent{
+ {Text: *msg.Content},
+ },
+ }
+
+ if len(filteredToolCalls) > 0 {
+ message.ToolCalls = &filteredToolCalls
+ }
+
+ bedrockMessages = append(bedrockMessages, message)
+ }
+
+ body := map[string]interface{}{
+ "messages": bedrockMessages,
+ }
+
+ return body, nil
+ }
+
+ return nil, fmt.Errorf("invalid model choice: %s", model)
+}
+
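+// GetChatCompletionTools maps Bifrost tool definitions to Bedrock's Anthropic toolSpec format; non-Anthropic models currently receive no tools.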
+func (provider *BedrockProvider) GetChatCompletionTools(params *interfaces.ModelParameters, model string) []BedrockAnthropicToolCall {
+ var tools []BedrockAnthropicToolCall
+
+ switch model {
+ case "anthropic.claude-instant-v1:2":
+ fallthrough
+ case "anthropic.claude-v2":
+ fallthrough
+ case "anthropic.claude-v2:1":
+ fallthrough
+ case "anthropic.claude-3-sonnet-20240229-v1:0":
+ fallthrough
+ case "anthropic.claude-3-5-sonnet-20240620-v1:0":
+ fallthrough
+ case "anthropic.claude-3-5-sonnet-20241022-v2:0":
+ fallthrough
+ case "anthropic.claude-3-5-haiku-20241022-v1:0":
+ fallthrough
+ case "anthropic.claude-3-opus-20240229-v1:0":
+ fallthrough
+ case "anthropic.claude-3-7-sonnet-20250219-v1:0":
+ for _, tool := range *params.Tools {
+ tools = append(tools, BedrockAnthropicToolCall{
+ ToolSpec: BedrockAnthropicToolSpec{
+ Name: tool.Function.Name,
+ Description: tool.Function.Description,
+ InputSchema: struct {
+ Json interface{} `json:"json"`
+ }{
+ Json: tool.Function.Parameters,
+ },
+ },
+ })
+ }
+ }
+
+ return tools
+}
+
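+// PrepareTextCompletionParams renames max_tokens to max_tokens_to_sample for the older Claude text completion models, which use the legacy parameter name.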
+func (provider *BedrockProvider) PrepareTextCompletionParams(params map[string]interface{}, model string) map[string]interface{} {
+ switch model {
+ case "anthropic.claude-instant-v1:2", "anthropic.claude-v2", "anthropic.claude-v2:1":
+ // Check if there is a key entry for max_tokens
+ if maxTokens, exists := params["max_tokens"]; exists {
+ // Check if max_tokens_to_sample is already present
+ if _, exists := params["max_tokens_to_sample"]; !exists {
+ // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample
+ params["max_tokens_to_sample"] = maxTokens
+ }
+ delete(params, "max_tokens")
+ }
+ }
+ return params
+}
+
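+// TextCompletion invokes a Bedrock text completion model and normalizes the response into a BifrostResponse.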
+func (provider *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params), model)
+
+ requestBody := MergeConfig(map[string]interface{}{
+ "prompt": text,
+ }, preparedParams)
+
+ // Marshal the request body
+ jsonData, err := json.Marshal(requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("error marshaling request: %v", err)
+ }
+
+ // Create the signed request for the model's invoke operation
+ req, err := provider.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key)
+ if err != nil {
+ return nil, fmt.Errorf("error creating request: %v", err)
+ }
+
+ // Execute the request
+ resp, err := provider.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to execute request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %v", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("bedrock API error: %s", string(body))
+ }
+
+ result, err := provider.GetTextCompletionResult(body, model)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse response body: %v", err)
+ }
+
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse raw response: %v", err)
+ }
+
+ result.ExtraFields.RawResponse = rawResponse
+
+ return result, nil
+}
+
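+// ChatCompletion calls Bedrock's converse API for the given model and converts the output into a BifrostResponse.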
+func (provider *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ messageBody, err := provider.PrepareChatCompletionMessages(messages, model)
+ if err != nil {
+ return nil, fmt.Errorf("error preparing messages: %v", err)
+ }
+
+ preparedParams := PrepareParams(params)
+
+ // Transform tools if present
+ if params != nil && params.Tools != nil && len(*params.Tools) > 0 {
+ preparedParams["tools"] = provider.GetChatCompletionTools(params, model)
+ }
+
+ requestBody := MergeConfig(messageBody, preparedParams)
+
+ // Marshal the request body
+ jsonData, err := json.Marshal(requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("error marshaling request: %v", err)
+ }
+
+ // Format the path with proper model identifier
+ path := fmt.Sprintf("%s/converse", model)
+
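+ // If an inference profile is configured for this model, route the call through the profile ARN instead of the bare model ID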
+ if provider.meta != nil && provider.meta.InferenceProfiles != nil {
+ if inferenceProfileId, ok := provider.meta.InferenceProfiles[model]; ok {
+ if provider.meta.ARN != nil {
+ encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.ARN, inferenceProfileId))
+ path = fmt.Sprintf("%s/converse", encodedModelIdentifier)
+ }
+ }
+ }
+
+ // Create the signed request
+ req, err := provider.PrepareReq(path, jsonData, key)
+ if err != nil {
+ return nil, fmt.Errorf("error creating request: %v", err)
+ }
+
+ // Execute the request
+ resp, err := provider.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to execute request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %v", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("bedrock API error: %s", string(body))
+ }
+
+ var response BedrockChatResponse
+ if err := json.Unmarshal(body, &response); err != nil {
+ return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+ }
+
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse raw response: %v", err)
+ }
+
+ var choices []interfaces.BifrostResponseChoice
+ for i, choice := range response.Output.Message.Content {
+ choices = append(choices, interfaces.BifrostResponseChoice{
+ Index: i,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: interfaces.RoleAssistant,
+ Content: &choice.Text,
+ },
+ FinishReason: &response.StopReason,
+ })
+ }
+
+ latency := float64(response.Metrics.Latency)
+
+ result := &interfaces.BifrostResponse{
+ Choices: choices,
+ Usage: interfaces.LLMUsage{
+ PromptTokens: response.Usage.InputTokens,
+ CompletionTokens: response.Usage.OutputTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ },
+ Model: model,
+
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Latency: &latency,
+ Provider: interfaces.Bedrock,
+ RawResponse: rawResponse,
+ },
+ }
+
+ return result, nil
+}
diff --git a/providers/cohere.go b/providers/cohere.go
new file mode 100644
index 000000000..4f14e3d67
--- /dev/null
+++ b/providers/cohere.go
@@ -0,0 +1,290 @@
+package providers
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "time"
+
+ "github.com/maximhq/bifrost/interfaces"
+ "github.com/valyala/fasthttp"
+)
+
+// CohereParameterDefinition represents a parameter definition for a Cohere tool
+type CohereParameterDefinition struct {
+ Type string `json:"type"`
+ Description *string `json:"description,omitempty"`
+ Required bool `json:"required"`
+}
+
+// CohereTool represents a tool definition for Cohere API
+type CohereTool struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"`
+}
+
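+// CohereToolCall represents a tool call returned by the Cohere API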
+type CohereToolCall struct {
+ Name string `json:"name"`
+ Parameters interface{} `json:"parameters"`
+}
+
+// CohereChatResponse represents the response from Cohere's chat API
+type CohereChatResponse struct {
+ ResponseID string `json:"response_id"`
+ Text string `json:"text"`
+ GenerationID string `json:"generation_id"`
+ ChatHistory []struct {
+ Role interfaces.ModelChatMessageRole `json:"role"`
+ Message string `json:"message"`
+ ToolCalls []CohereToolCall `json:"tool_calls"`
+ } `json:"chat_history"`
+ FinishReason string `json:"finish_reason"`
+ Meta struct {
+ APIVersion struct {
+ Version string `json:"version"`
+ } `json:"api_version"`
+ BilledUnits struct {
+ InputTokens float64 `json:"input_tokens"`
+ OutputTokens float64 `json:"output_tokens"`
+ } `json:"billed_units"`
+ Tokens struct {
+ InputTokens float64 `json:"input_tokens"`
+ OutputTokens float64 `json:"output_tokens"`
+ } `json:"tokens"`
+ } `json:"meta"`
+ ToolCalls []CohereToolCall `json:"tool_calls"`
+}
+
+// CohereProvider implements the Provider interface for Cohere
+type CohereProvider struct {
+ client *fasthttp.Client
+}
+
+// NewCohereProvider creates a new Cohere provider instance
+func NewCohereProvider(config *interfaces.ProviderConfig) *CohereProvider {
+ return &CohereProvider{
+ client: &fasthttp.Client{
+ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+ WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+ MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize,
+ },
+ }
+}
+
+func (provider *CohereProvider) GetProviderKey() interfaces.SupportedModelProvider {
+ return interfaces.Cohere
+}
+
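+// TextCompletion is not offered by Cohere's chat API, so this always returns an error.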
+func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ return nil, fmt.Errorf("text completion is not supported by Cohere")
+}
+
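+// ChatCompletion calls Cohere's v1 chat API, translating Bifrost messages, tools, and parameters to and from Cohere's format.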
+func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ // Cohere takes the latest message separately from the prior chat history
+ if len(messages) == 0 {
+ return nil, fmt.Errorf("no messages provided")
+ }
+ lastMessage := messages[len(messages)-1]
+ chatHistory := messages[:len(messages)-1]
+
+ // Transform chat history
+ var cohereHistory []map[string]interface{}
+ for _, msg := range chatHistory {
+ cohereHistory = append(cohereHistory, map[string]interface{}{
+ "role": msg.Role,
+ "message": msg.Content,
+ })
+ }
+
+ preparedParams := PrepareParams(params)
+
+ // Prepare request body
+ requestBody := MergeConfig(map[string]interface{}{
+ "message": lastMessage.Content,
+ "chat_history": cohereHistory,
+ "model": model,
+ }, preparedParams)
+
+ // Add tools if present
+ if params != nil && params.Tools != nil && len(*params.Tools) > 0 {
+ var tools []CohereTool
+ for _, tool := range *params.Tools {
+ parameterDefinitions := make(map[string]CohereParameterDefinition)
+ funcParams := tool.Function.Parameters
+ for name, prop := range funcParams.Properties {
+ propMap, ok := prop.(map[string]interface{})
+ if ok {
+ paramDef := CohereParameterDefinition{
+ Required: slices.Contains(funcParams.Required, name),
+ }
+ }
+
+ if typeStr, ok := propMap["type"].(string); ok {
+ paramDef.Type = typeStr
+ }
+
+ if desc, ok := propMap["description"].(string); ok {
+ paramDef.Description = &desc
+ }
+
+ parameterDefinitions[name] = paramDef
+ }
+ }
+
+ tools = append(tools, CohereTool{
+ Name: tool.Function.Name,
+ Description: tool.Function.Description,
+ ParameterDefinitions: parameterDefinitions,
+ })
+ }
+ requestBody["tools"] = tools
+ }
+
+ // Marshal request body
+ jsonBody, err := json.Marshal(requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("error marshaling request: %v", err)
+ }
+
+ // Create request
+ req := fasthttp.AcquireRequest()
+ resp := fasthttp.AcquireResponse()
+ defer fasthttp.ReleaseRequest(req)
+ defer fasthttp.ReleaseResponse(resp)
+
+ req.SetRequestURI("https://api.cohere.ai/v1/chat")
+ req.Header.SetMethod("POST")
+ req.Header.SetContentType("application/json")
+ req.Header.Set("Authorization", "Bearer "+key)
+ req.SetBody(jsonBody)
+
+ // Make request
+ if err := provider.client.Do(req, resp); err != nil {
+ return nil, fmt.Errorf("error making request: %v", err)
+ }
+
+ // Handle error response
+ if resp.StatusCode() != fasthttp.StatusOK {
+ return nil, fmt.Errorf("cohere error: %s", resp.Body())
+ }
+
+ // Read response body
+ body := resp.Body()
+
+ // Decode response
+ var response CohereChatResponse
+ if err := json.Unmarshal(body, &response); err != nil {
+ return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+ }
+
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse raw response: %v", err)
+ }
+
+ // Transform tool calls if present
+ var toolCalls []interfaces.ToolCall
+ if response.ToolCalls != nil {
+ for _, tool := range response.ToolCalls {
+ function := interfaces.FunctionCall{
+ Name: &tool.Name,
+ }
+
+ args, err := json.Marshal(tool.Parameters)
+ if err != nil {
+ function.Arguments = fmt.Sprintf("%v", tool.Parameters)
+ } else {
+ function.Arguments = string(args)
+ }
+
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ Function: function,
+ })
+ }
+ }
+
+ // Get role and content from the last message in chat history
+ var role interfaces.ModelChatMessageRole
+ var content string
+ if len(response.ChatHistory) > 0 {
+ lastMsg := response.ChatHistory[len(response.ChatHistory)-1]
+ role = lastMsg.Role
+ content = lastMsg.Message
+ } else {
+ role = interfaces.RoleChatbot
+ content = response.Text
+ }
+
+ // Create completion result
+ result := &interfaces.BifrostResponse{
+ ID: response.ResponseID,
+ Choices: []interfaces.BifrostResponseChoice{
+ {
+ Index: 0,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: role,
+ Content: &content,
+ ToolCalls: &toolCalls,
+ },
+ FinishReason: &response.FinishReason,
+ },
+ },
+ Usage: interfaces.LLMUsage{
+ PromptTokens: int(response.Meta.Tokens.InputTokens),
+ CompletionTokens: int(response.Meta.Tokens.OutputTokens),
+ TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens),
+ },
+
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Provider: interfaces.Cohere,
+ BilledUsage: &interfaces.BilledLLMUsage{
+ PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens),
+ CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens),
+ },
+ ChatHistory: convertChatHistory(response.ChatHistory),
+ RawResponse: rawResponse,
+ },
+ Model: model,
+ }
+
+ return result, nil
+}
+
+// Helper function to convert chat history to the correct type
+func convertChatHistory(history []struct {
+ Role interfaces.ModelChatMessageRole `json:"role"`
+ Message string `json:"message"`
+ ToolCalls []CohereToolCall `json:"tool_calls"`
+}) *[]interfaces.BifrostResponseChoiceMessage {
+ converted := make([]interfaces.BifrostResponseChoiceMessage, len(history))
+ for i, msg := range history {
+ var toolCalls []interfaces.ToolCall
+ if msg.ToolCalls != nil {
+ for _, tool := range msg.ToolCalls {
+ function := interfaces.FunctionCall{
+ Name: &tool.Name,
+ }
+
+ args, err := json.Marshal(tool.Parameters)
+ if err != nil {
+ function.Arguments = fmt.Sprintf("%v", tool.Parameters)
+ } else {
+ function.Arguments = string(args)
+ }
+
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ Function: function,
+ })
+ }
+ }
+ converted[i] = interfaces.BifrostResponseChoiceMessage{
+ Role: msg.Role,
+ Content: &msg.Message,
+ ToolCalls: &toolCalls,
+ }
+ }
+ return &converted
+}
+
+// Helper function to create a pointer to a float64
+func float64Ptr(f float64) *float64 {
+ return &f
+}
diff --git a/providers/openai.go b/providers/openai.go
index bdc1385d4..3c12ec768 100644
--- a/providers/openai.go
+++ b/providers/openai.go
@@ -1,24 +1,38 @@
package providers
import (
- "bifrost/interfaces"
- "bytes"
"encoding/json"
"fmt"
- "net/http"
"time"
+
+ "github.com/maximhq/bifrost/interfaces"
+ "github.com/valyala/fasthttp"
)
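+// OpenAIResponse mirrors the JSON payload returned by OpenAI's chat completions endpoint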
+type OpenAIResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"` // text.completion or chat.completion
+ Choices []interfaces.BifrostResponseChoice `json:"choices"`
+ Model string `json:"model"`
+ Created int `json:"created"` // The Unix timestamp (in seconds).
+ ServiceTier *string `json:"service_tier"`
+ SystemFingerprint *string `json:"system_fingerprint"`
+ Usage interfaces.LLMUsage `json:"usage"`
+}
+
// OpenAIProvider implements the Provider interface for OpenAI
type OpenAIProvider struct {
- //* Do we even need it?
- client *http.Client
+ client *fasthttp.Client
}
// NewOpenAIProvider creates a new OpenAI provider instance
-func NewOpenAIProvider() *OpenAIProvider {
+func NewOpenAIProvider(config *interfaces.ProviderConfig) *OpenAIProvider {
return &OpenAIProvider{
- client: &http.Client{Timeout: time.Second * 30},
+ client: &fasthttp.Client{
+ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+ WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
+ MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize,
+ },
}
}
@@ -27,35 +41,14 @@ func (provider *OpenAIProvider) GetProviderKey() interfaces.SupportedModelProvid
}
// TextCompletion performs text completion
-func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
return nil, fmt.Errorf("text completion is not supported by OpenAI")
}
-// sanitizeParameters cleans up the parameters for OpenAI
-func (provider *OpenAIProvider) sanitizeParameters(params *interfaces.ModelParameters) *interfaces.ModelParameters {
- sanitized := params
- if sanitized == nil {
- return nil
- }
-
- if params.ExtraParams != nil {
- // For logprobs, if it's disabled, we remove top_logprobs
- if _, exists := params.ExtraParams["logprobs"]; !exists {
- delete(sanitized.ExtraParams, "top_logprobs")
- }
- }
-
- return sanitized
-}
-
-// ChatCompletion implements chat completion using OpenAI's API
-func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- startTime := time.Now()
-
+func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
// Format messages for OpenAI API
var openAIMessages []map[string]interface{}
for _, msg := range messages {
-
var content any
if msg.Content != nil {
content = msg.Content
@@ -69,8 +62,6 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
})
}
- // Sanitize parameters
- params = provider.sanitizeParameters(params)
preparedParams := PrepareParams(params)
requestBody := MergeConfig(map[string]interface{}{
@@ -84,75 +75,55 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
}
// Create request
- req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(jsonBody))
- if err != nil {
- return nil, fmt.Errorf("error creating request: %v", err)
- }
-
- // Add headers
- req.Header.Set("Content-Type", "application/json")
+ req := fasthttp.AcquireRequest()
+ resp := fasthttp.AcquireResponse()
+ defer fasthttp.ReleaseRequest(req)
+ defer fasthttp.ReleaseResponse(resp)
+
+ req.SetRequestURI("https://api.openai.com/v1/chat/completions")
+ req.Header.SetMethod("POST")
+ req.Header.SetContentType("application/json")
req.Header.Set("Authorization", "Bearer "+key)
+ req.SetBody(jsonBody)
// Make request
- resp, err := provider.client.Do(req)
- if err != nil {
+ if err := provider.client.Do(req, resp); err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
- defer resp.Body.Close()
-
- latency := time.Since(startTime).Seconds()
// Handle error response
- if resp.StatusCode != http.StatusOK {
- var errorResp struct {
- Error struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param any `json:"param"`
- Code string `json:"code"`
- } `json:"error"`
- }
- if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil {
- return nil, fmt.Errorf("error decoding error response: %v", err)
- }
- return nil, fmt.Errorf("OpenAI error: %s", errorResp.Error.Message)
+ if resp.StatusCode() != fasthttp.StatusOK {
+ return nil, fmt.Errorf("OpenAI error: %s", resp.Body())
}
- // Decode response
- var rawResult struct {
- ID string `json:"id"`
- Choices []interfaces.CompletionResultChoice `json:"choices"`
- Usage interfaces.LLMUsage `json:"usage"`
- Model string `json:"model"`
- Created interface{} `json:"created"`
- }
+ body := resp.Body()
- if err := json.NewDecoder(resp.Body).Decode(&rawResult); err != nil {
- return nil, fmt.Errorf("error decoding response: %v", err)
+ // Decode structured response
+ var response OpenAIResponse
+ if err := json.Unmarshal(body, &response); err != nil {
+ return nil, fmt.Errorf("error decoding structured response: %v", err)
}
- // Convert the raw result to CompletionResult
- result := &interfaces.CompletionResult{
- ID: rawResult.ID,
- Choices: rawResult.Choices,
- Usage: rawResult.Usage,
- Model: rawResult.Model,
+ // Decode raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("error decoding raw response: %v", err)
}
- // Handle the created field conversion
- if rawResult.Created != nil {
- switch v := rawResult.Created.(type) {
- case float64:
- // Convert Unix timestamp to string
- result.Created = fmt.Sprintf("%d", int64(v))
- case string:
- result.Created = v
- }
+ result := &interfaces.BifrostResponse{
+ ID: response.ID,
+ Choices: response.Choices,
+ Object: response.Object,
+ Usage: response.Usage,
+ ServiceTier: response.ServiceTier,
+ SystemFingerprint: response.SystemFingerprint,
+ Created: response.Created,
+ Model: response.Model,
+ ExtraFields: interfaces.BifrostResponseExtraFields{
+ Provider: interfaces.OpenAI,
+ RawResponse: rawResponse,
+ },
}
- // Add provider-specific information
- result.Provider = interfaces.OpenAI
- result.Usage.Latency = latency
-
return result, nil
}
diff --git a/providers/utils.go b/providers/utils.go
index bd925de75..9acbef7ab 100644
--- a/providers/utils.go
+++ b/providers/utils.go
@@ -1,8 +1,24 @@
package providers
import (
- "bifrost/interfaces"
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "net/http"
"reflect"
+ "strings"
+ "time"
+
+ "github.com/maximhq/bifrost/interfaces"
+
+ "maps"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/aws-sdk-go-v2/config"
)
// MergeConfig merges default config with custom parameters
@@ -35,7 +51,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} {
typ := val.Type()
// Iterate through all fields
- for i := 0; i < val.NumField(); i++ {
+ for i := range val.NumField() {
field := val.Field(i)
fieldType := typ.Field(i)
@@ -50,6 +66,9 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} {
continue
}
+ // Strip out ,omitempty and others from the tag
+ jsonTag = strings.Split(jsonTag, ",")[0]
+
// Handle pointer fields
if field.Kind() == reflect.Ptr && !field.IsNil() {
flatParams[jsonTag] = field.Elem().Interface()
@@ -57,9 +76,64 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} {
}
// Handle ExtraParams
- for k, v := range params.ExtraParams {
- flatParams[k] = v
- }
+ maps.Copy(flatParams, params.ExtraParams)
return flatParams
}
+
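+// SignAWSRequest hashes the request body and signs the request with AWS Signature V4 using the supplied static credentials.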
+func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string) error {
+ // Set required headers before signing
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ // Calculate SHA256 hash of the request body
+ var bodyHash string
+ if req.Body != nil {
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ return fmt.Errorf("failed to read request body: %v", err)
+ }
+ // Restore the body for subsequent reads
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+ hash := sha256.Sum256(bodyBytes)
+ bodyHash = hex.EncodeToString(hash[:])
+ } else {
+ // For empty body, use the hash of an empty string
+ hash := sha256.Sum256([]byte{})
+ bodyHash = hex.EncodeToString(hash[:])
+ }
+
+ cfg, err := config.LoadDefaultConfig(context.TODO(),
+ config.WithRegion(region),
+ config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) {
+ creds := aws.Credentials{
+ AccessKeyID: accessKey,
+ SecretAccessKey: secretKey,
+ }
+ if sessionToken != nil {
+ creds.SessionToken = *sessionToken
+ }
+ return creds, nil
+ })),
+ )
+ if err != nil {
+ return fmt.Errorf("failed to load AWS config: %v", err)
+ }
+
+ // Create the AWS signer
+ signer := v4.NewSigner()
+
+ // Get credentials
+ creds, err := cfg.Credentials.Retrieve(context.TODO())
+ if err != nil {
+ return fmt.Errorf("failed to retrieve credentials: %v", err)
+ }
+
+ // Sign the request with AWS Signature V4
+ if err := signer.SignHTTP(context.TODO(), creds, req, bodyHash, service, region, time.Now()); err != nil {
+ return fmt.Errorf("failed to sign request: %v", err)
+ }
+
+ return nil
+}
diff --git a/tests/account.go b/tests/account.go
index 81aa72e34..174fee9b7 100644
--- a/tests/account.go
+++ b/tests/account.go
@@ -1,119 +1,110 @@
package tests
import (
- "bifrost/interfaces"
- "bifrost/providers"
"fmt"
-)
-
-// BaseAccount provides a basic implementation of the Account interface
-type BaseAccount struct {
- providers map[string]interfaces.Provider
- keys map[string][]interfaces.Key
- config map[string]interfaces.ConcurrencyAndBufferSize
-}
-
-type ProviderConfig map[string]struct {
- Keys []interfaces.Key `json:"keys"`
- ConcurrencyConfig interfaces.ConcurrencyAndBufferSize `json:"concurrency_config"`
-}
-
-func (baseAccount *BaseAccount) Init(config ProviderConfig) error {
- baseAccount.providers = make(map[string]interfaces.Provider)
- baseAccount.keys = make(map[string][]interfaces.Key)
- baseAccount.config = make(map[string]interfaces.ConcurrencyAndBufferSize)
-
- for providerKey, providerData := range config {
- // Create provider instance based on the key
- provider, err := baseAccount.createProvider(providerKey, providerData.Keys)
- if err != nil {
- return fmt.Errorf("failed to create provider %s: %v", providerKey, err)
- }
-
- fmt.Println("✅ provider created")
-
- // Add provider to the account
- baseAccount.AddProvider(provider)
+ "os"
- // Add keys for the provider
- for _, keyData := range providerData.Keys {
- key := interfaces.Key{
- Value: keyData.Value,
- Models: keyData.Models,
- Weight: keyData.Weight,
- }
+ "github.com/maximhq/bifrost/interfaces"
- baseAccount.AddKey(providerKey, key)
- }
+ "github.com/maximhq/maxim-go"
+)
- // Set provider configuration
- baseAccount.SetProviderConcurrencyConfig(providerKey, providerData.ConcurrencyConfig)
- }
+// BaseAccount provides a basic implementation of the Account interface for Anthropic and OpenAI providers
+type BaseAccount struct{}
- return nil
+// GetInitiallyConfiguredProviders returns all providers configured for this account
+func (baseAccount *BaseAccount) GetInitiallyConfiguredProviders() ([]interfaces.SupportedModelProvider, error) {
+ return []interfaces.SupportedModelProvider{interfaces.OpenAI, interfaces.Anthropic, interfaces.Bedrock}, nil
}
-// createProvider creates a new provider instance based on the provider key
-func (ba *BaseAccount) createProvider(providerKey string, keys []interfaces.Key) (interfaces.Provider, error) {
- if len(keys) == 0 {
- return nil, fmt.Errorf("no keys found for provider: %s", providerKey)
- }
-
- switch interfaces.SupportedModelProvider(providerKey) {
+// GetKeysForProvider returns all keys associated with a provider
+func (baseAccount *BaseAccount) GetKeysForProvider(providerKey interfaces.SupportedModelProvider) ([]interfaces.Key, error) {
+ switch providerKey {
case interfaces.OpenAI:
- return providers.NewOpenAIProvider(), nil
+ return []interfaces.Key{
+ {
+ Value: os.Getenv("OPEN_AI_API_KEY"),
+ Models: []string{"gpt-4o-mini"},
+ Weight: 1.0,
+ },
+ }, nil
case interfaces.Anthropic:
- return providers.NewAnthropicProvider(), nil
+ return []interfaces.Key{
+ {
+ Value: os.Getenv("ANTHROPIC_API_KEY"),
+ Models: []string{"claude-3-7-sonnet-20250219", "claude-2.1"},
+ Weight: 1.0,
+ },
+ }, nil
+ case interfaces.Bedrock:
+ return []interfaces.Key{
+ {
+ Value: os.Getenv("BEDROCK_API_KEY"),
+ Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"},
+ Weight: 1.0,
+ },
+ }, nil
+ case interfaces.Cohere:
+ return []interfaces.Key{
+ {
+ Value: os.Getenv("COHERE_API_KEY"),
+ Models: []string{"command-a-03-2025"},
+ Weight: 1.0,
+ },
+ }, nil
default:
return nil, fmt.Errorf("unsupported provider: %s", providerKey)
}
}
-// GetInitiallyConfiguredProviders returns all configured providers
-func (ba *BaseAccount) GetInitiallyConfiguredProviders() ([]interfaces.Provider, error) {
- providers := make([]interfaces.Provider, 0, len(ba.providers))
- for _, provider := range ba.providers {
- providers = append(providers, provider)
- }
-
- return providers, nil
-}
-
-// GetKeysForProvider returns all keys associated with a provider
-func (ba *BaseAccount) GetKeysForProvider(provider interfaces.Provider) ([]interfaces.Key, error) {
- providerKey := string(provider.GetProviderKey())
- if keys, exists := ba.keys[providerKey]; exists {
- return keys, nil
- }
-
- return nil, fmt.Errorf("no keys found for provider: %s", providerKey)
-}
-
-// GetConcurrencyAndBufferSizeForProvider returns the concurrency and buffer size settings for a provider
+// GetConfigForProvider returns the network and concurrency configuration for a provider
-func (ba *BaseAccount) GetConcurrencyAndBufferSizeForProvider(provider interfaces.Provider) (interfaces.ConcurrencyAndBufferSize, error) {
- providerKey := string(provider.GetProviderKey())
- if config, exists := ba.config[providerKey]; exists {
- return config, nil
+func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.SupportedModelProvider) (*interfaces.ProviderConfig, error) {
+ switch providerKey {
+ case interfaces.OpenAI:
+ return &interfaces.ProviderConfig{
+ NetworkConfig: interfaces.NetworkConfig{
+ DefaultRequestTimeoutInSeconds: 30,
+ },
+ ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
+ Concurrency: 3,
+ BufferSize: 10,
+ },
+ }, nil
+ case interfaces.Anthropic:
+ return &interfaces.ProviderConfig{
+ NetworkConfig: interfaces.NetworkConfig{
+ DefaultRequestTimeoutInSeconds: 30,
+ },
+ ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
+ Concurrency: 3,
+ BufferSize: 10,
+ },
+ }, nil
+ case interfaces.Bedrock:
+ return &interfaces.ProviderConfig{
+ NetworkConfig: interfaces.NetworkConfig{
+ DefaultRequestTimeoutInSeconds: 30,
+ },
+ MetaConfig: &interfaces.MetaConfig{
+ SecretAccessKey: maxim.StrPtr(os.Getenv("BEDROCK_ACCESS_KEY")),
+ Region: maxim.StrPtr("us-east-1"),
+ },
+ ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
+ Concurrency: 3,
+ BufferSize: 10,
+ },
+ }, nil
+ case interfaces.Cohere:
+ return &interfaces.ProviderConfig{
+ NetworkConfig: interfaces.NetworkConfig{
+ DefaultRequestTimeoutInSeconds: 30,
+ },
+ ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
+ Concurrency: 3,
+ BufferSize: 10,
+ },
+ }, nil
+ default:
+ return nil, fmt.Errorf("unsupported provider: %s", providerKey)
}
-
- // Default values if not configured
- return interfaces.ConcurrencyAndBufferSize{
- Concurrency: 5,
- BufferSize: 100,
- }, nil
-}
-
-// AddProvider adds a new provider to the account
-func (ba *BaseAccount) AddProvider(provider interfaces.Provider) {
- ba.providers[string(provider.GetProviderKey())] = provider
-}
-
-// AddKey adds a new key for a provider
-func (ba *BaseAccount) AddKey(providerKey string, key interfaces.Key) {
- ba.keys[providerKey] = append(ba.keys[providerKey], key)
-}
-
-// SetProviderConfig sets the concurrency and buffer size for a provider
-func (ba *BaseAccount) SetProviderConcurrencyConfig(providerKey string, config interfaces.ConcurrencyAndBufferSize) {
- ba.config[providerKey] = config
}
diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go
index 952a0c16e..0fa2eb793 100644
--- a/tests/anthropic_test.go
+++ b/tests/anthropic_test.go
@@ -1,41 +1,48 @@
package tests
import (
- "bifrost"
- "bifrost/interfaces"
+ "context"
"fmt"
"testing"
"time"
+
+ "github.com/maximhq/bifrost"
+ "github.com/maximhq/bifrost/interfaces"
)
// setupAnthropicRequests sends multiple test requests to Anthropic
func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
- anthropicMessages := []string{
- "What's your favorite programming language?",
- "Can you help me write a Go function?",
- "What's the best way to learn programming?",
- "Tell me about artificial intelligence.",
+ ctx := context.Background()
+
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ MaxTokens: &maxTokens,
}
+ // Text completion request
go func() {
- config := interfaces.ModelParameters{
- ExtraParams: map[string]interface{}{
- "max_tokens_to_sample": 4096,
- },
- }
+ text := "Hello world!"
- result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, "claude-2.1", "Hello world!", &config)
+ result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
+ Model: "claude-2.1",
+ Input: interfaces.RequestInput{
+ TextInput: &text,
+ },
+ Params: ¶ms,
+ }, ctx)
if err != nil {
fmt.Println("Error:", err)
} else {
- fmt.Println("🤖 Text Completion Result:", result.Choices[0].Message.Content)
+ fmt.Println("🤖 Text Completion Result:", *result.Choices[0].Message.Content)
}
}()
- config := interfaces.ModelParameters{
- ExtraParams: map[string]interface{}{
- "max_tokens": 4096,
- },
+ // Regular chat completion requests
+ anthropicMessages := []string{
+ "Hello! How are you today?",
+ "Tell me a joke!",
+ "What's your favorite programming language?",
}
for i, message := range anthropicMessages {
@@ -44,15 +51,86 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
time.Sleep(delay)
messages := []interfaces.Message{
{
- Role: interfaces.UserRole,
+ Role: interfaces.RoleUser,
Content: &msg,
},
}
- result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, "claude-3-7-sonnet-20250219", messages, &config)
+ result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
+ Model: "claude-3-7-sonnet-20250219",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+
if err != nil {
fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err)
} else {
- fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, result.Choices[0].Message.Content)
+ fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content)
+ }
+ }(message, delay, i)
+ }
+
+ // Tool calls test
+ setupAnthropicToolCalls(bifrost, ctx)
+}
+
+// setupAnthropicToolCalls tests Anthropic's function calling capability
+func setupAnthropicToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ anthropicMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ MaxTokens: &maxTokens,
+ }
+
+ for i, message := range anthropicMessages {
+ delay := time.Duration(500+100*i) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
+ Model: "claude-3-7-sonnet-20250219",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+
+ if err != nil {
+ fmt.Printf("Error in Anthropic tool call request %d: %v\n", index+1, err)
+ } else {
+ if len(result.Choices) > 1 && result.Choices[1].Message.ToolCalls != nil && len(*result.Choices[1].Message.ToolCalls) > 0 {
+ toolCall := *result.Choices[1].Message.ToolCalls
+ fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments)
+ } else {
+ fmt.Printf("🤖 No tool calls in response %d\n", index+1)
+ }
}
}(message, delay, i)
}
diff --git a/tests/bedrock_test.go b/tests/bedrock_test.go
new file mode 100644
index 000000000..9c1791931
--- /dev/null
+++ b/tests/bedrock_test.go
@@ -0,0 +1,154 @@
+package tests
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/maximhq/bifrost"
+ "github.com/maximhq/bifrost/interfaces"
+)
+
+// setupBedrockRequests sends multiple test requests to Bedrock
+func setupBedrockRequests(bifrost *bifrost.Bifrost) {
+ ctx := context.Background()
+
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ MaxTokens: &maxTokens,
+ }
+
+ // Text completion request
+ go func() {
+ text := "\n\nHuman:\n\nAssistant:"
+
+ result, err := bifrost.TextCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
+ Model: "anthropic.claude-v2:1",
+ Input: interfaces.RequestInput{
+ TextInput: &text,
+ },
+ Params: ¶ms,
+ }, ctx)
+ if err != nil {
+ fmt.Println("Error:", err)
+ } else {
+ fmt.Println("🤖 Text Completion Result:", *result.Choices[0].Message.Content)
+ }
+ }()
+
+ // Regular chat completion requests
+ bedrockMessages := []string{
+ "Hello! How are you today?",
+ "Tell me a joke!",
+ "What's your favorite programming language?",
+ }
+
+ for i, message := range bedrockMessages {
+ delay := time.Duration(500+100*i) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
+ Model: "anthropic.claude-3-sonnet-20240229-v1:0",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+
+ if err != nil {
+ fmt.Printf("Error in Bedrock request %d: %v\n", index+1, err)
+ } else {
+ fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content)
+ }
+ }(message, delay, i)
+ }
+
+ // Tool calls test
+ setupBedrockToolCalls(bifrost, ctx)
+}
+
+// setupBedrockToolCalls tests Bedrock's function calling capability
+func setupBedrockToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ bedrockMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ MaxTokens: &maxTokens,
+ }
+
+ for i, message := range bedrockMessages {
+ delay := time.Duration(500+100*i) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
+ Model: "anthropic.claude-3-sonnet-20240229-v1:0",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+
+ if err != nil {
+ fmt.Printf("Error in Bedrock tool call request %d: %v\n", index+1, err)
+ } else {
+ if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 {
+ toolCall := *result.Choices[0].Message.ToolCalls
+ fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments)
+ } else {
+ fmt.Printf("🤖 No tool calls in response %d\n", index+1)
+ fmt.Println("Raw JSON Response", result.ExtraFields.RawResponse)
+ }
+ }
+ }(message, delay, i)
+ }
+}
+
+func TestBedrock(t *testing.T) {
+ bifrost, err := getBifrost()
+ if err != nil {
+ t.Fatalf("Error initializing bifrost: %v", err)
+ return
+ }
+
+ setupBedrockRequests(bifrost)
+
+ bifrost.Cleanup()
+}
diff --git a/tests/cohere_test.go b/tests/cohere_test.go
new file mode 100644
index 000000000..750d59399
--- /dev/null
+++ b/tests/cohere_test.go
@@ -0,0 +1,137 @@
+package tests
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/maximhq/bifrost"
+ "github.com/maximhq/bifrost/interfaces"
+)
+
+// setupCohereRequests sends multiple test requests to Cohere
+func setupCohereRequests(bifrost *bifrost.Bifrost) {
+ text := "Hello world!"
+ ctx := context.Background()
+
+ // Text completion request
+ go func() {
+ result, err := bifrost.TextCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{
+ Model: "command-a-03-2025",
+ Input: interfaces.RequestInput{
+ TextInput: &text,
+ },
+ Params: nil,
+ }, ctx)
+ if err != nil {
+ fmt.Println("Error:", err)
+ } else {
+ fmt.Println("🐒 Text Completion Result:", result.Choices[0].Message.Content)
+ }
+ }()
+
+ // Regular chat completion requests
+ cohereMessages := []string{
+ "Hello! How are you today?",
+ "Tell me a joke!",
+ "What's your favorite programming language?",
+ }
+
+ for i, message := range cohereMessages {
+ delay := time.Duration(100*(i+1)) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{
+ Model: "command-a-03-2025",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: nil,
+ }, ctx)
+ if err != nil {
+ fmt.Printf("Error in Cohere request %d: %v\n", index+1, err)
+ } else {
+ fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content)
+ }
+ }(message, delay, i)
+ }
+
+ // Tool calls test
+ setupCohereToolCalls(bifrost, ctx)
+}
+
+// setupCohereToolCalls tests Cohere's function calling capability
+func setupCohereToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ cohereMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ }
+
+ for i, message := range cohereMessages {
+ delay := time.Duration(100*(i+1)) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{
+ Model: "command-a-03-2025",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+ if err != nil {
+ fmt.Printf("Error in Cohere tool call request %d: %v\n", index+1, err)
+ } else {
+ if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 {
+ toolCall := *result.Choices[0].Message.ToolCalls
+ fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments)
+ } else {
+ fmt.Printf("🐒 No tool calls in response %d\n", index+1)
+ }
+ }
+ }(message, delay, i)
+ }
+}
+
+func TestCohere(t *testing.T) {
+ bifrost, err := getBifrost()
+ if err != nil {
+ t.Fatalf("Error initializing bifrost: %v", err)
+ return
+ }
+
+ setupCohereRequests(bifrost)
+
+ bifrost.Cleanup()
+}
diff --git a/tests/openai_test.go b/tests/openai_test.go
index 1e82a31d5..dd73c6855 100644
--- a/tests/openai_test.go
+++ b/tests/openai_test.go
@@ -1,18 +1,29 @@
package tests
import (
- "bifrost"
- "bifrost/interfaces"
+ "context"
"fmt"
"testing"
"time"
+
+ "github.com/maximhq/bifrost"
+ "github.com/maximhq/bifrost/interfaces"
)
// setupOpenAIRequests sends multiple test requests to OpenAI
func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
+ text := "Hello world!"
+ ctx := context.Background()
+
// Text completion request
go func() {
- result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, "gpt-4o-mini", "Hello world!", nil)
+ result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
+ Model: "gpt-4o-mini",
+ Input: interfaces.RequestInput{
+ TextInput: &text,
+ },
+ Params: nil,
+ }, ctx)
if err != nil {
fmt.Println("Error:", err)
} else {
@@ -20,10 +31,9 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
}
}()
- // Chat completion requests with different messages and delays
+ // Regular chat completion requests
openAIMessages := []string{
"Hello! How are you today?",
- "What's the weather like?",
"Tell me a joke!",
"What's your favorite programming language?",
}
@@ -34,15 +44,81 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
time.Sleep(delay)
messages := []interfaces.Message{
{
- Role: interfaces.UserRole,
+ Role: interfaces.RoleUser,
Content: &msg,
},
}
- result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, "gpt-4o-mini", messages, nil)
+ result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
+ Model: "gpt-4o-mini",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: nil,
+ }, ctx)
if err != nil {
fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err)
} else {
- fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, result.Choices[0].Message.Content)
+ fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content)
+ }
+ }(message, delay, i)
+ }
+
+ // Tool calls test
+ setupOpenAIToolCalls(bifrost, ctx)
+}
+
+// setupOpenAIToolCalls tests OpenAI's function calling capability
+func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ openAIMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ }
+
+ for i, message := range openAIMessages {
+ delay := time.Duration(100*(i+1)) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
+ Model: "gpt-4o-mini",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+ if err != nil {
+ fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err)
+ } else {
+ if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 {
+ toolCall := *result.Choices[0].Message.ToolCalls
+ fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments)
+ } else {
+ fmt.Printf("🐒 No tool calls in response %d\n", index+1)
+ }
}
}(message, delay, i)
}
diff --git a/tests/plugin.go b/tests/plugin.go
new file mode 100644
index 000000000..12395435a
--- /dev/null
+++ b/tests/plugin.go
@@ -0,0 +1,56 @@
+package tests
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/maximhq/bifrost/interfaces"
+
+ "github.com/maximhq/maxim-go"
+ "github.com/maximhq/maxim-go/logging"
+)
+
+// Define a custom type for context key to avoid collisions
+type contextKey string
+
+const (
+ traceIDKey contextKey = "traceID"
+)
+
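+// Plugin logs each Bifrost request/response pair as a trace in Maxim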
+type Plugin struct {
+ logger *logging.Logger
+}
+
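+// PreHook opens a new Maxim trace for the incoming request and stores its ID in the request context for PostHook.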
+func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) {
+ traceID := time.Now().Format("20060102_150405000")
+
+ trace := plugin.logger.Trace(&logging.TraceConfig{
+ Id: traceID,
+ Name: maxim.StrPtr("bifrost"),
+ })
+
+ trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req))
+
+ if ctx != nil {
+ // Store traceID in context
+ *ctx = context.WithValue(*ctx, traceIDKey, traceID)
+ }
+
+ return req, nil
+}
+
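+// PostHook looks up the trace ID stored by PreHook and records the response as that trace's output.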
+func (plugin *Plugin) PostHook(ctxRef *context.Context, res *interfaces.BifrostResponse) (*interfaces.BifrostResponse, error) {
+ // Get traceID from context
+ if ctxRef != nil {
+ ctx := *ctxRef
+ traceID, ok := ctx.Value(traceIDKey).(string)
+ if !ok {
+ return res, fmt.Errorf("traceID not found in context")
+ }
+
+ plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res))
+ }
+
+ return res, nil
+}
diff --git a/tests/setup.go b/tests/setup.go
index efa82723f..812a6ccf3 100644
--- a/tests/setup.go
+++ b/tests/setup.go
@@ -1,12 +1,16 @@
package tests
import (
- "bifrost"
- "bifrost/interfaces"
+ "fmt"
"log"
"os"
+ "github.com/maximhq/bifrost"
+ "github.com/maximhq/bifrost/interfaces"
+
"github.com/joho/godotenv"
+ "github.com/maximhq/maxim-go"
+ "github.com/maximhq/maxim-go/logging"
)
func loadEnv() {
@@ -16,39 +20,33 @@ func loadEnv() {
}
}
+func getPlugin() (interfaces.Plugin, error) {
+ loadEnv()
+
+ mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: os.Getenv("MAXIM_API_KEY")})
+
+ logger, err := mx.GetLogger(&logging.LoggerConfig{Id: os.Getenv("MAXIM_LOGGER_ID")})
+ if err != nil {
+ return nil, err
+ }
+
+ plugin := &Plugin{logger}
+
+ return plugin, nil
+}
+
func getBifrost() (*bifrost.Bifrost, error) {
loadEnv()
account := BaseAccount{}
- if err := account.Init(
- ProviderConfig{
- "openai": {
- Keys: []interfaces.Key{
- {Value: os.Getenv("OPEN_AI_API_KEY"), Weight: 1.0, Models: []string{"gpt-4o-mini"}},
- },
- ConcurrencyConfig: interfaces.ConcurrencyAndBufferSize{
- Concurrency: 3,
- BufferSize: 10,
- },
- },
- "anthropic": {
- Keys: []interfaces.Key{
- {Value: os.Getenv("ANTHROPIC_API_KEY"), Weight: 1.0, Models: []string{"claude-3-7-sonnet-20250219", "claude-2.1"}},
- },
- ConcurrencyConfig: interfaces.ConcurrencyAndBufferSize{
- Concurrency: 3,
- BufferSize: 10,
- },
- },
- },
- ); err != nil {
- log.Fatal("Error initializing account:", err)
+ plugin, err := getPlugin()
+ if err != nil {
+ fmt.Println("Error setting up the plugin:", err)
return nil, err
}
- bifrost, err := bifrost.Init(&account)
+ bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin})
if err != nil {
- log.Fatal("Error initializing bifrost:", err)
return nil, err
}