Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 80 additions & 20 deletions bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,25 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.Sup
return supportedKeys[0].Value, nil
}

// calculateBackoff implements exponential backoff with jitter
func (bifrost *Bifrost) calculateBackoff(attempt int, config *interfaces.ProviderConfig) time.Duration {
// Calculate an exponential backoff: initial * 2^attempt
backoff := config.NetworkConfig.RetryBackoffInitial * time.Duration(1<<uint(attempt))
if backoff > config.NetworkConfig.RetryBackoffMax {
backoff = config.NetworkConfig.RetryBackoffMax
}

// Add jitter (±20%)
jitter := float64(backoff) * (0.8 + 0.4*rand.Float64())

return time.Duration(jitter)
}

func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
defer bifrost.waitGroups[provider.GetProviderKey()].Done()

for req := range queue {
var result *interfaces.BifrostResponse
var err error
var bifrostError *interfaces.BifrostError

key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
Expand All @@ -253,35 +266,82 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
Error: err,
},
}
continue
}

config, err := bifrost.account.GetConfigForProvider(provider.GetProviderKey())
if err != nil {
req.Err <- interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Message: err.Error(),
Error: err,
},
}
continue
}

if req.Type == TextCompletionRequest {
if req.Input.TextCompletionInput == nil {
bifrostError = &interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Message: "text not provided for text completion request",
},
}
} else {
result, bifrostError = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params)
// Track attempts
var attempts int

// Execute request with retries
for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ {
if attempts > 0 {
// Log retry attempt
bifrost.logger.Info(fmt.Sprintf(
"Retrying request (attempt %d/%d) for model %s: %s",
attempts, config.NetworkConfig.MaxRetries, req.Model,
bifrostError.Error.Message,
))

// Calculate and apply backoff
backoff := bifrost.calculateBackoff(attempts-1, config)
time.Sleep(backoff)
}
} else if req.Type == ChatCompletionRequest {
if req.Input.ChatCompletionInput == nil {
bifrostError = &interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Message: "chats not provided for chat completion request",
},

// Attempt the request
if req.Type == TextCompletionRequest {
if req.Input.TextCompletionInput == nil {
bifrostError = &interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Message: "text not provided for text completion request",
},
}
break // Don't retry client errors
} else {
result, bifrostError = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params)
}
} else if req.Type == ChatCompletionRequest {
if req.Input.ChatCompletionInput == nil {
bifrostError = &interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Message: "chats not provided for chat completion request",
},
}
break // Don't retry client errors
} else {
result, bifrostError = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params)
}
} else {
result, bifrostError = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params)
}

// Check if successful or if we should retry
if bifrostError == nil ||
//TODO should have a better way to check for only network errors
bifrostError.IsBifrostError || // Only retry non-bifrost errors
attempts == config.NetworkConfig.MaxRetries {
break
}
}

if bifrostError != nil {
// Add retry information to error
if attempts > 0 {
bifrost.logger.Warn(fmt.Sprintf("Request failed after %d %s",
attempts,
map[bool]string{true: "retries", false: "retry"}[attempts > 1]))
}
req.Err <- *bifrostError
} else {
req.Response <- result
Expand Down
7 changes: 6 additions & 1 deletion interfaces/provider.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package interfaces

import "time"

// TODO third party providers

type NetworkConfig struct {
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
MaxRetries int `json:"max_retries"`
RetryBackoffInitial time.Duration `json:"retry_backoff_initial"`
RetryBackoffMax time.Duration `json:"retry_backoff_max"`
}

type MetaConfig struct {
Expand Down
2 changes: 1 addition & 1 deletion providers/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
// Make request
if err := provider.client.Do(req, resp); err != nil {
return nil, &interfaces.BifrostError{
IsBifrostError: true,
IsBifrostError: false,
Error: interfaces.ErrorField{
Message: ErrOpenAIRequest.Error(),
Error: err,
Expand Down
13 changes: 13 additions & 0 deletions tests/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tests
import (
"fmt"
"os"
"time"

"github.com/maximhq/bifrost/interfaces"

Expand Down Expand Up @@ -64,6 +65,9 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp
return &interfaces.ProviderConfig{
NetworkConfig: interfaces.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
MaxRetries: 3,
RetryBackoffInitial: 100 * time.Millisecond,
RetryBackoffMax: 2 * time.Second,
},
ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
Concurrency: 3,
Expand All @@ -74,6 +78,9 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp
return &interfaces.ProviderConfig{
NetworkConfig: interfaces.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
MaxRetries: 3,
RetryBackoffInitial: 100 * time.Millisecond,
RetryBackoffMax: 2 * time.Second,
},
ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
Concurrency: 3,
Expand All @@ -84,6 +91,9 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp
return &interfaces.ProviderConfig{
NetworkConfig: interfaces.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
MaxRetries: 3,
RetryBackoffInitial: 100 * time.Millisecond,
RetryBackoffMax: 2 * time.Second,
},
MetaConfig: &interfaces.MetaConfig{
SecretAccessKey: maxim.StrPtr(os.Getenv("BEDROCK_ACCESS_KEY")),
Expand All @@ -98,6 +108,9 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp
return &interfaces.ProviderConfig{
NetworkConfig: interfaces.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
MaxRetries: 3,
RetryBackoffInitial: 100 * time.Millisecond,
RetryBackoffMax: 2 * time.Second,
},
ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
Concurrency: 3,
Expand Down