Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 120 additions & 56 deletions bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ type ChannelMessage struct {

// Bifrost manages providers and maintains infinite open channels
type Bifrost struct {
account interfaces.Account
providers []interfaces.Provider // list of processed providers
plugins []interfaces.Plugin
requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues
waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup
logger interfaces.Logger
account interfaces.Account
providers []interfaces.Provider // list of processed providers
plugins []interfaces.Plugin
requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues
waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup
channelMessagePool sync.Pool // Pool for ChannelMessage objects
responseChannelPool sync.Pool // Pool for response channels
errorChannelPool sync.Pool // Pool for error channels
logger interfaces.Logger
}

func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) {
Expand Down Expand Up @@ -88,8 +91,37 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro

// Initializes infinite listening channels for each provider
func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interfaces.Logger) (*Bifrost, error) {
bifrost := &Bifrost{account: account, plugins: plugins}
bifrost.waitGroups = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup)
bifrost := &Bifrost{
account: account,
plugins: plugins,
waitGroups: make(map[interfaces.SupportedModelProvider]*sync.WaitGroup),
requestQueues: make(map[interfaces.SupportedModelProvider]chan ChannelMessage),
}

// Initialize object pools
bifrost.channelMessagePool = sync.Pool{
New: func() interface{} {
return &ChannelMessage{}
},
}
bifrost.responseChannelPool = sync.Pool{
New: func() interface{} {
return make(chan *interfaces.BifrostResponse, 1)
},
}
bifrost.errorChannelPool = sync.Pool{
New: func() interface{} {
return make(chan interfaces.BifrostError, 1)
},
}

// Prewarm pools with multiple objects
for range 2500 {
// Create and put new objects directly into pools
bifrost.channelMessagePool.Put(&ChannelMessage{})
bifrost.responseChannelPool.Put(make(chan *interfaces.BifrostResponse, 1))
bifrost.errorChannelPool.Put(make(chan interfaces.BifrostError, 1))
}

providerKeys, err := bifrost.account.GetInitiallyConfiguredProviders()
if err != nil {
Expand All @@ -101,8 +133,6 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf
}
bifrost.logger = logger

bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage)

// Create buffered channels for each provider and start workers
for _, providerKey := range providerKeys {
config, err := bifrost.account.GetConfigForProvider(providerKey)
Expand All @@ -119,6 +149,44 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf
return bifrost, nil
}

// getChannelMessage gets a ChannelMessage from the pool
func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType RequestType) *ChannelMessage {
// Get channels from pool
responseChan := bifrost.responseChannelPool.Get().(chan *interfaces.BifrostResponse)
errorChan := bifrost.errorChannelPool.Get().(chan interfaces.BifrostError)

// Clear any previous values to avoid leaking between requests
select {
case <-responseChan:
default:
}
select {
case <-errorChan:
default:
}

// Get message from pool and configure it
msg := bifrost.channelMessagePool.Get().(*ChannelMessage)
msg.BifrostRequest = req
msg.Response = responseChan
msg.Err = errorChan
msg.Type = reqType

return msg
}

// releaseChannelMessage returns a ChannelMessage and its channels to the pool
func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
// Put channels back in pools
bifrost.responseChannelPool.Put(msg.Response)
bifrost.errorChannelPool.Put(msg.Err)

// Clear references and return to pool
msg.Response = nil
msg.Err = nil
bifrost.channelMessagePool.Put(msg)
}

func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.SupportedModelProvider, model string) (string, error) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil {
Expand All @@ -138,37 +206,34 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.Sup
}

if len(supportedKeys) == 0 {
return "", fmt.Errorf("no keys found supporting model: %s", model)
return "", fmt.Errorf("no keys found that support model: %s", model)
}

// Create a new random source
randomSource := rand.New(rand.NewSource(time.Now().UnixNano()))

// Shuffle keys using the new random number generator
randomSource.Shuffle(len(supportedKeys), func(i, j int) {
supportedKeys[i], supportedKeys[j] = supportedKeys[j], supportedKeys[i]
})
if len(supportedKeys) == 1 {
return supportedKeys[0].Value, nil
}

// Compute the cumulative weight sum
var totalWeight float64
// Use a weighted random selection based on key weights
totalWeight := 0
for _, key := range supportedKeys {
totalWeight += key.Weight
totalWeight += int(key.Weight * 100) // Convert float to int for better performance
}

// Generate a random number within total weight
r := randomSource.Float64() * totalWeight
var cumulative float64
// Use a fast random number generator
randomSource := rand.New(rand.NewSource(time.Now().UnixNano()))
randomValue := randomSource.Intn(totalWeight)

// Select the key based on weighted probability
// Select key based on weight
currentWeight := 0
for _, key := range supportedKeys {
cumulative += key.Weight
if r <= cumulative {
currentWeight += int(key.Weight * 100)
if randomValue < currentWeight {
return key.Value, nil
}
}

// Fallback (should never happen)
return supportedKeys[len(supportedKeys)-1].Value, nil
// Fallback to first key if something goes wrong
return supportedKeys[0].Value, nil
}

func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
Expand Down Expand Up @@ -276,9 +341,6 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
}
}

responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan interfaces.BifrostError)

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
if err != nil {
Expand All @@ -300,20 +362,19 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
}
}

queue <- ChannelMessage{
BifrostRequest: *req,
Response: responseChan,
Err: errorChan,
Type: TextCompletionRequest,
}
// Get a ChannelMessage from the pool
msg := bifrost.getChannelMessage(*req, TextCompletionRequest)
queue <- *msg

// Handle response
var result *interfaces.BifrostResponse
select {
case result := <-responseChan:
case result = <-msg.Response:
// Run plugins in reverse order
for i := len(bifrost.plugins) - 1; i >= 0; i-- {
result, err = bifrost.plugins[i].PostHook(&ctx, result)

if err != nil {
bifrost.releaseChannelMessage(msg)
return nil, &interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Expand All @@ -322,11 +383,14 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
}
}
}

return result, nil
case err := <-errorChan:
case err := <-msg.Err:
bifrost.releaseChannelMessage(msg)
return nil, &err
}

// Return message to pool
bifrost.releaseChannelMessage(msg)
return result, nil
}

func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
Expand All @@ -349,9 +413,6 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
}
}

responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan interfaces.BifrostError)

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
if err != nil {
Expand All @@ -373,20 +434,19 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
}
}

queue <- ChannelMessage{
BifrostRequest: *req,
Response: responseChan,
Err: errorChan,
Type: ChatCompletionRequest,
}
// Get a ChannelMessage from the pool
msg := bifrost.getChannelMessage(*req, ChatCompletionRequest)
queue <- *msg

// Handle response
var result *interfaces.BifrostResponse
select {
case result := <-responseChan:
case result = <-msg.Response:
// Run plugins in reverse order
for i := len(bifrost.plugins) - 1; i >= 0; i-- {
result, err = bifrost.plugins[i].PostHook(&ctx, result)

if err != nil {
bifrost.releaseChannelMessage(msg)
return nil, &interfaces.BifrostError{
IsBifrostError: false,
Error: interfaces.ErrorField{
Expand All @@ -396,10 +456,14 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
}
}

return result, nil
case err := <-errorChan:
case err := <-msg.Err:
bifrost.releaseChannelMessage(msg)
return nil, &err
}

// Return message to pool
bifrost.releaseChannelMessage(msg)
return result, nil
}

// Shutdown gracefully stops all workers when triggered
Expand Down
10 changes: 10 additions & 0 deletions interfaces/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ type ConcurrencyAndBufferSize struct {
BufferSize int `json:"buffer_size"`
}

// ProxyType defines the type of proxy to use
type ProxyType string

const (
NoProxy ProxyType = "none"
HttpProxy ProxyType = "http"
Socks5Proxy ProxyType = "socks5"
EnvProxy ProxyType = "environment"
)

type ProviderConfig struct {
NetworkConfig NetworkConfig `json:"network_config"`
MetaConfig *MetaConfig `json:"meta_config,omitempty"`
Expand Down
Loading