Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 73 additions & 58 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package bifrost
import (
"context"
"fmt"
"math/rand"
"slices"
"sort"
"strings"
Expand All @@ -15,6 +14,7 @@ import (
"time"

"github.com/google/uuid"
keysortingalgos "github.com/maximhq/bifrost/core/key-sorting-algos"
"github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/azure"
Expand Down Expand Up @@ -70,7 +70,7 @@ type Bifrost struct {
mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured)
mcpInitOnce sync.Once // Ensures MCP manager is initialized only once
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
keySelector schemas.KeySelector // Custom key selector function
keySorter schemas.KeySorterFunc // key sorter function to use
}

// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
Expand Down Expand Up @@ -112,7 +112,6 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
plugins: atomic.Pointer[[]schemas.Plugin]{},
requestQueues: sync.Map{},
waitGroups: sync.Map{},
keySelector: config.KeySelector,
logger: config.Logger,
}
bifrost.plugins.Store(&config.Plugins)
Expand All @@ -122,11 +121,24 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {

bifrost.dropExcessRequests.Store(config.DropExcessRequests)

if bifrost.keySelector == nil {
bifrost.keySelector = WeightedRandomKeySelector
if config.KeySorterAlgorithm == "" {
// Use default key sorter algorithm
config.KeySorterAlgorithm = schemas.KeySorterWeightedRandom
}

// Initialize object pools
switch config.KeySorterAlgorithm {
case schemas.KeySorterWeightedRandom:
bifrost.keySorter = keysortingalgos.WeightedRandomKeySorter
case schemas.KeySorterCustom:
if config.CustomKeySorter == nil {
return nil, fmt.Errorf("custom key sorter function is required when using custom key sorter algorithm")
}
bifrost.keySorter = config.CustomKeySorter
default:
return nil, fmt.Errorf("unsupported key sorter algorithm: %s", config.KeySorterAlgorithm)
}

// Initialize object pool
bifrost.channelMessagePool = sync.Pool{
New: func() interface{} {
return &ChannelMessage{}
Expand Down Expand Up @@ -316,7 +328,7 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
}
}

response, bifrostErr := executeRequestWithRetries(&ctx, config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
response, bifrostErr := executeRequestWithRetries(&ctx, config, func(_ int) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
return provider.ListModels(ctx, keys, request)
}, schemas.ListModelsRequest, req.Provider, "")
if bifrostErr != nil {
Expand Down Expand Up @@ -2719,13 +2731,43 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro
}
}

// selectKeyAndUpdateTriedKeys selects a key based on the attempt number and updates
// the tried keys tracking in the context. It modifies the tried key slices in place
// and returns the selected key.
func selectKeyAndUpdateTriedKeys(
attempts int,
keys []schemas.Key,
triedKeyIDs *[]string,
triedKeyNames *[]string,
ctx *context.Context,
) schemas.Key {
var selectedKey schemas.Key
if len(keys) > 0 {
i := clamp(attempts, 0, len(keys)-1)
selectedKey = keys[i]
*ctx = context.WithValue(*ctx, schemas.BifrostContextKeySelectedKeyID, selectedKey.ID)
*ctx = context.WithValue(*ctx, schemas.BifrostContextKeySelectedKeyName, selectedKey.Name)
if attempts > 0 {
// Compute the previously used key index
prevIndex := clamp(attempts-1, 0, len(keys)-1)
prevKey := keys[prevIndex]
// Append previous key's ID/Name and assign back to accumulate tried keys
*triedKeyIDs = append(*triedKeyIDs, prevKey.ID)
*triedKeyNames = append(*triedKeyNames, prevKey.Name)
*ctx = context.WithValue(*ctx, schemas.BifrostContextKeyTriedKeyIDs, *triedKeyIDs)
*ctx = context.WithValue(*ctx, schemas.BifrostContextKeyTriedKeyNames, *triedKeyNames)
}
}
return selectedKey
}

// executeRequestWithRetries is a generic function that handles common request processing logic
// It consolidates retry logic, backoff calculation, and error handling
// It is not a bifrost method because interface methods in go cannot be generic
func executeRequestWithRetries[T any](
ctx *context.Context,
config *schemas.ProviderConfig,
requestHandler func() (T, *schemas.BifrostError),
requestHandler func(attempts int) (T, *schemas.BifrostError),
requestType schemas.RequestType,
providerKey schemas.ModelProvider,
model string,
Expand Down Expand Up @@ -2758,7 +2800,7 @@ func executeRequestWithRetries[T any](
logger.Debug("attempting %s request for provider %s", requestType, providerKey)

// Attempt the request
result, bifrostError = requestHandler()
result, bifrostError = requestHandler(attempts)

logger.Debug("request %s for provider %s completed", requestType, providerKey)

Expand Down Expand Up @@ -2823,7 +2865,6 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
baseProvider = cfg.BaseProviderType
}

key := schemas.Key{}
var keys []schemas.Key
if providerRequiresKey(baseProvider, config.CustomProviderConfig) {
// Determine if this is a multi-key batch/file operation
Expand Down Expand Up @@ -2855,7 +2896,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
}
} else {
// Use the custom provider name for actual key selection, but pass base provider type for key validation
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
keys, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
if err != nil {
bifrost.logger.Debug("error selecting key for model %s: %v", model, err)
req.Err <- schemas.BifrostError{
Expand All @@ -2872,8 +2913,6 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
}
continue
}
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyID, key.ID)
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyName, key.Name)
}
}
// Create plugin pipeline for streaming requests outside retry loop to prevent leaks
Expand All @@ -2891,13 +2930,17 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
}

// Execute request with retries
var triedKeyIDs []string
var triedKeyNames []string
if IsStreamRequestType(req.RequestType) {
stream, bifrostError = executeRequestWithRetries(&req.Context, config, func() (chan *schemas.BifrostStream, *schemas.BifrostError) {
return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner)
stream, bifrostError = executeRequestWithRetries(&req.Context, config, func(attempts int) (chan *schemas.BifrostStream, *schemas.BifrostError) {
selectedKey := selectKeyAndUpdateTriedKeys(attempts, keys, &triedKeyIDs, &triedKeyNames, &req.Context)
return bifrost.handleProviderStreamRequest(provider, req, selectedKey, postHookRunner)
}, req.RequestType, provider.GetProviderKey(), model)
} else {
result, bifrostError = executeRequestWithRetries(&req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) {
return bifrost.handleProviderRequest(provider, req, key, keys)
result, bifrostError = executeRequestWithRetries(&req.Context, config, func(attempts int) (*schemas.BifrostResponse, *schemas.BifrostError) {
selectedKey := selectKeyAndUpdateTriedKeys(attempts, keys, &triedKeyIDs, &triedKeyNames, &req.Context)
return bifrost.handleProviderRequest(provider, req, selectedKey, keys)
}, req.RequestType, provider.GetProviderKey(), model)
}

Expand Down Expand Up @@ -3399,26 +3442,26 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *context.Context, provider

// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) {
// Check if key has been set in the context explicitly
if ctx != nil {
key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
if ok {
return key, nil
return []schemas.Key{key}, nil
}
}
// Check if key skipping is allowed
if skipKeySelection, ok := (*ctx).Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) {
return schemas.Key{}, nil
return nil, nil
}
// Get keys for provider
keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
if err != nil {
return schemas.Key{}, err
return nil, err
}
// Check if no keys found
if len(keys) == 0 {
return schemas.Key{}, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model)
return nil, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model)
}

// For batch API operations, filter keys to only include those with UseForBatchAPI enabled
Expand All @@ -3430,7 +3473,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requ
}
}
if len(batchEnabledKeys) == 0 {
return schemas.Key{}, fmt.Errorf("no config found for batch APIs. Please enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey)
return nil, fmt.Errorf("no config found for batch APIs. Please enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey)
}
keys = batchEnabledKeys
}
Expand Down Expand Up @@ -3486,9 +3529,9 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requ
}
if len(supportedKeys) == 0 {
if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock || baseProviderType == schemas.Vertex {
return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model)
return nil, fmt.Errorf("no keys found that support model/deployment: %s", model)
}
return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model)
return nil, fmt.Errorf("no keys found that support model: %s", model)
}

var requestedKeyName string
Expand All @@ -3501,47 +3544,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requ
if requestedKeyName != "" {
for _, key := range supportedKeys {
if key.Name == requestedKeyName {
return key, nil
return []schemas.Key{key}, nil
}
}
return schemas.Key{}, fmt.Errorf("no key found with name %q for provider: %v", requestedKeyName, providerKey)
return nil, fmt.Errorf("no key found with name %q for provider: %v", requestedKeyName, providerKey)
}

if len(supportedKeys) == 1 {
return supportedKeys[0], nil
}

selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
sortedKeys, err := bifrost.keySorter(ctx, supportedKeys, providerKey, model)
if err != nil {
return schemas.Key{}, err
}

return selectedKey, nil

}

func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
// Use a weighted random selection based on key weights
totalWeight := 0
for _, key := range keys {
totalWeight += int(key.Weight * 100) // Convert float to int for better performance
return nil, err
}

// Use a fast random number generator
randomSource := rand.New(rand.NewSource(time.Now().UnixNano()))
randomValue := randomSource.Intn(totalWeight)

// Select key based on weight
currentWeight := 0
for _, key := range keys {
currentWeight += int(key.Weight * 100)
if randomValue < currentWeight {
return key, nil
}
}
return sortedKeys, nil

// Fallback to first key if something goes wrong
return keys[0], nil
}

// Shutdown gracefully stops all workers when triggered.
Expand Down
12 changes: 6 additions & 6 deletions core/bifrost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) {
// Test immediate success
t.Run("ImmediateSuccess", func(t *testing.T) {
callCount := 0
handler := func() (string, *schemas.BifrostError) {
handler := func(attempts int) (string, *schemas.BifrostError) {
callCount++
return "success", nil
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) {
// Test success after retries
t.Run("SuccessAfterRetries", func(t *testing.T) {
callCount := 0
handler := func() (string, *schemas.BifrostError) {
handler := func(attempts int) (string, *schemas.BifrostError) {
callCount++
if callCount <= 2 {
// First two calls fail with retryable error
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) {
ctx := context.Background()
t.Run("ExceedsMaxRetries", func(t *testing.T) {
callCount := 0
handler := func() (string, *schemas.BifrostError) {
handler := func(attempts int) (string, *schemas.BifrostError) {
callCount++
// Always fail with retryable error
return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false)
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
callCount := 0
handler := func() (string, *schemas.BifrostError) {
handler := func(attempts int) (string, *schemas.BifrostError) {
callCount++
return "", tc.error
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
callCount := 0
handler := func() (string, *schemas.BifrostError) {
handler := func(attempts int) (string, *schemas.BifrostError) {
callCount++
return "", tc.error
}
Expand Down Expand Up @@ -477,7 +477,7 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) {
var attemptCounts []int
callCount := 0

handler := func() (string, *schemas.BifrostError) {
handler := func(attempts int) (string, *schemas.BifrostError) {
callCount++
attemptCounts = append(attemptCounts, callCount)

Expand Down
Loading