diff --git a/bifrost.go b/bifrost.go index 155fbfc2f..18fb496e9 100644 --- a/bifrost.go +++ b/bifrost.go @@ -31,12 +31,15 @@ type ChannelMessage struct { // Bifrost manages providers and maintains infinite open channels type Bifrost struct { - account interfaces.Account - providers []interfaces.Provider // list of processed providers - plugins []interfaces.Plugin - requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues - waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup - logger interfaces.Logger + account interfaces.Account + providers []interfaces.Provider // list of processed providers + plugins []interfaces.Plugin + requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues + waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup + channelMessagePool sync.Pool // Pool for ChannelMessage objects + responseChannelPool sync.Pool // Pool for response channels + errorChannelPool sync.Pool // Pool for error channels + logger interfaces.Logger } func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) { @@ -88,8 +91,37 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro // Initializes infinite listening channels for each provider func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interfaces.Logger) (*Bifrost, error) { - bifrost := &Bifrost{account: account, plugins: plugins} - bifrost.waitGroups = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup) + bifrost := &Bifrost{ + account: account, + plugins: plugins, + waitGroups: make(map[interfaces.SupportedModelProvider]*sync.WaitGroup), + requestQueues: make(map[interfaces.SupportedModelProvider]chan ChannelMessage), + } + + // Initialize object pools + bifrost.channelMessagePool = sync.Pool{ + New: func() interface{} { + return &ChannelMessage{} + }, + } + bifrost.responseChannelPool = sync.Pool{ + New: func() interface{} { + return make(chan *interfaces.BifrostResponse, 1) + }, + } + bifrost.errorChannelPool = sync.Pool{ + New: func() interface{} { + return make(chan interfaces.BifrostError, 1) + }, + } + + // Prewarm pools with multiple objects + for range 2500 { + // Create and put new objects directly into pools + bifrost.channelMessagePool.Put(&ChannelMessage{}) + bifrost.responseChannelPool.Put(make(chan *interfaces.BifrostResponse, 1)) + bifrost.errorChannelPool.Put(make(chan interfaces.BifrostError, 1)) + } providerKeys, err := bifrost.account.GetInitiallyConfiguredProviders() if err != nil { @@ -101,8 +133,6 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf } bifrost.logger = logger - bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage) - // Create buffered channels for each provider and start workers for _, providerKey := range providerKeys { config, err := bifrost.account.GetConfigForProvider(providerKey) @@ -119,6 +149,44 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf return bifrost, nil } +// getChannelMessage gets a ChannelMessage from the pool +func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType RequestType) *ChannelMessage { + // Get channels from pool + responseChan := bifrost.responseChannelPool.Get().(chan *interfaces.BifrostResponse) + errorChan := bifrost.errorChannelPool.Get().(chan interfaces.BifrostError) + + // Clear any previous values to avoid leaking between requests + select { + case <-responseChan: + default: + } + select { + case <-errorChan: + default: + } + + // Get message from pool and configure it + msg := bifrost.channelMessagePool.Get().(*ChannelMessage) + msg.BifrostRequest = req + msg.Response = responseChan + msg.Err = errorChan + msg.Type = reqType + + return msg +} + +// releaseChannelMessage returns a ChannelMessage and its channels to the pool +func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { + // Put channels back in pools + bifrost.responseChannelPool.Put(msg.Response) + bifrost.errorChannelPool.Put(msg.Err) + + // Clear references and return to pool + msg.Response = nil + msg.Err = nil + bifrost.channelMessagePool.Put(msg) +} + func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.SupportedModelProvider, model string) (string, error) { keys, err := bifrost.account.GetKeysForProvider(providerKey) if err != nil { @@ -138,37 +206,34 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.Sup } if len(supportedKeys) == 0 { - return "", fmt.Errorf("no keys found supporting model: %s", model) + return "", fmt.Errorf("no keys found that support model: %s", model) } - // Create a new random source - randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) - - // Shuffle keys using the new random number generator - randomSource.Shuffle(len(supportedKeys), func(i, j int) { - supportedKeys[i], supportedKeys[j] = supportedKeys[j], supportedKeys[i] - }) + if len(supportedKeys) == 1 { + return supportedKeys[0].Value, nil + } - // Compute the cumulative weight sum - var totalWeight float64 + // Use a weighted random selection based on key weights + totalWeight := 0 for _, key := range supportedKeys { - totalWeight += key.Weight + totalWeight += int(key.Weight * 100) // Convert float to int for better performance } - // Generate a random number within total weight - r := randomSource.Float64() * totalWeight - var cumulative float64 + // Use a fast random number generator + randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) + randomValue := randomSource.Intn(totalWeight) - // Select the key based on weighted probability + // Select key based on weight + currentWeight := 0 for _, key := range supportedKeys { - cumulative += key.Weight - if r <= cumulative { + currentWeight += int(key.Weight * 100) + if randomValue < currentWeight { return key.Value, nil } } - // Fallback (should never happen) - return supportedKeys[len(supportedKeys)-1].Value, nil + // Fallback to first key if something goes wrong + return supportedKeys[0].Value, nil } func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) { @@ -276,9 +341,6 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo } } - responseChan := make(chan *interfaces.BifrostResponse) - errorChan := make(chan interfaces.BifrostError) - for _, plugin := range bifrost.plugins { req, err = plugin.PreHook(&ctx, req) if err != nil { @@ -300,20 +362,19 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo } } - queue <- ChannelMessage{ - BifrostRequest: *req, - Response: responseChan, - Err: errorChan, - Type: TextCompletionRequest, - } + // Get a ChannelMessage from the pool + msg := bifrost.getChannelMessage(*req, TextCompletionRequest) + queue <- *msg + // Handle response + var result *interfaces.BifrostResponse select { - case result := <-responseChan: + case result = <-msg.Response: // Run plugins in reverse order for i := len(bifrost.plugins) - 1; i >= 0; i-- { result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { + bifrost.releaseChannelMessage(msg) return nil, &interfaces.BifrostError{ IsBifrostError: false, Error: interfaces.ErrorField{ @@ -322,11 +383,14 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo } } } - - return result, nil - case err := <-errorChan: + case err := <-msg.Err: + bifrost.releaseChannelMessage(msg) return nil, &err } + + // Return message to pool + bifrost.releaseChannelMessage(msg) + return result, nil } func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) { @@ -349,9 +413,6 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo } } - responseChan := make(chan *interfaces.BifrostResponse) - errorChan := make(chan interfaces.BifrostError) - for _, plugin := range bifrost.plugins { req, err = plugin.PreHook(&ctx, req) if err != nil { @@ -373,20 +434,19 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo } } - queue <- ChannelMessage{ - BifrostRequest: *req, - Response: responseChan, - Err: errorChan, - Type: ChatCompletionRequest, - } + // Get a ChannelMessage from the pool + msg := bifrost.getChannelMessage(*req, ChatCompletionRequest) + queue <- *msg + // Handle response + var result *interfaces.BifrostResponse select { - case result := <-responseChan: + case result = <-msg.Response: // Run plugins in reverse order for i := len(bifrost.plugins) - 1; i >= 0; i-- { result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { + bifrost.releaseChannelMessage(msg) return nil, &interfaces.BifrostError{ IsBifrostError: false, Error: interfaces.ErrorField{ @@ -396,10 +456,14 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo } } - return result, nil - case err := <-errorChan: + case err := <-msg.Err: + bifrost.releaseChannelMessage(msg) return nil, &err } + + // Return message to pool + bifrost.releaseChannelMessage(msg) + return result, nil } // Shutdown gracefully stops all workers when triggered diff --git a/interfaces/provider.go b/interfaces/provider.go index bb5b6ddfc..bcbf2f7e1 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -19,6 +19,16 @@ type ConcurrencyAndBufferSize struct { BufferSize int `json:"buffer_size"` } +// ProxyType defines the type of proxy to use +type ProxyType string + +const ( + NoProxy ProxyType = "none" + HttpProxy ProxyType = "http" + Socks5Proxy ProxyType = "socks5" + EnvProxy ProxyType = "environment" +) + type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` MetaConfig *MetaConfig `json:"meta_config,omitempty"` diff --git a/providers/openai.go b/providers/openai.go index b7f8b44ae..89ff5de6b 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -2,14 +2,68 @@ package providers import ( "encoding/json" - + "fmt" + "sync" "time" "github.com/maximhq/bifrost/interfaces" "github.com/valyala/fasthttp" ) -type OpenAIChatResponse struct { +// Pre-defined errors to reduce allocations in error paths +var ( + ErrOpenAIRequest = fmt.Errorf("error making OpenAI request") + ErrOpenAIResponse = fmt.Errorf("OpenAI error response") + ErrOpenAIJSONMarshaling = fmt.Errorf("error marshaling OpenAI request") + ErrOpenAIDecodeStructured = fmt.Errorf("error decoding OpenAI structured response") + ErrOpenAIDecodeRaw = fmt.Errorf("error decoding OpenAI raw response") + ErrOpenAIDecompress = fmt.Errorf("error decompressing OpenAI response") + ErrOpenAIProxyConfig = fmt.Errorf("invalid proxy configuration") +) + +// OpenAIResponsePool provides a pool for OpenAI response objects +var openAIResponsePool = sync.Pool{ + New: func() interface{} { + return &OpenAIResponse{} + }, +} + +// BifrostResponsePool provides a pool for Bifrost response objects +var bifrostResponsePool = sync.Pool{ + New: func() interface{} { + return &interfaces.BifrostResponse{} + }, +} + +// AcquireOpenAIResponse gets an OpenAI response from the pool +func AcquireOpenAIResponse() *OpenAIResponse { + resp := openAIResponsePool.Get().(*OpenAIResponse) + *resp = OpenAIResponse{} // Reset the struct + return resp +} + +// ReleaseOpenAIResponse returns an OpenAI response to the pool +func ReleaseOpenAIResponse(resp *OpenAIResponse) { + if resp != nil { + openAIResponsePool.Put(resp) + } +} + +// AcquireBifrostResponse gets a Bifrost response from the pool +func AcquireBifrostResponse() *interfaces.BifrostResponse { + resp := bifrostResponsePool.Get().(*interfaces.BifrostResponse) + *resp = interfaces.BifrostResponse{} // Reset the struct + return resp +} + +// ReleaseBifrostResponse returns a Bifrost response to the pool +func ReleaseBifrostResponse(resp *interfaces.BifrostResponse) { + if resp != nil { + bifrostResponsePool.Put(resp) + } +} + +type OpenAIResponse struct { ID string `json:"id"` Object string `json:"object"` // text.completion or chat.completion Choices []interfaces.BifrostResponseChoice `json:"choices"` @@ -40,13 +94,22 @@ type OpenAIProvider struct { // NewOpenAIProvider creates a new OpenAI provider instance func NewOpenAIProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *OpenAIProvider { + // Create the client + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + for range config.ConcurrencyAndBufferSize.Concurrency { + // Create and put new objects directly into pools + openAIResponsePool.Put(&OpenAIResponse{}) + bifrostResponsePool.Put(&interfaces.BifrostResponse{}) + } + return &OpenAIProvider{ logger: logger, - client: &fasthttp.Client{ - ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, - }, + client: client, } } @@ -116,7 +179,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int return nil, &interfaces.BifrostError{ IsBifrostError: true, Error: interfaces.ErrorField{ - Message: "error marshaling request", + Message: ErrOpenAIJSONMarshaling.Error(), Error: err, }, } @@ -139,7 +202,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int return nil, &interfaces.BifrostError{ IsBifrostError: true, Error: interfaces.ErrorField{ - Message: "error sending request", + Message: ErrOpenAIRequest.Error(), Error: err, }, } @@ -152,7 +215,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int return nil, &interfaces.BifrostError{ IsBifrostError: true, Error: interfaces.ErrorField{ - Message: "error parsing error response", + Message: ErrOpenAIResponse.Error(), Error: err, }, } @@ -174,45 +237,64 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int } } - body := resp.Body() + responseBody := resp.Body() + + // Pre-allocate response structs from pools + openAIResponse := AcquireOpenAIResponse() + defer ReleaseOpenAIResponse(openAIResponse) + + result := AcquireBifrostResponse() - // Decode structured response - var response OpenAIChatResponse - if err := json.Unmarshal(body, &response); err != nil { + // Parallel Unmarshaling of response + var wg sync.WaitGroup + var structuredErr, rawErr error + var rawResponse interface{} + + wg.Add(2) + go func() { + defer wg.Done() + structuredErr = json.Unmarshal(responseBody, openAIResponse) + }() + go func() { + defer wg.Done() + rawErr = json.Unmarshal(responseBody, &rawResponse) + }() + wg.Wait() + + // Check for unmarshaling errors + if structuredErr != nil { + ReleaseBifrostResponse(result) return nil, &interfaces.BifrostError{ IsBifrostError: true, Error: interfaces.ErrorField{ - Message: "error parsing response", - Error: err, + Message: ErrOpenAIDecodeStructured.Error(), + Error: structuredErr, }, } } - - // Decode raw response - var rawResponse interface{} - if err := json.Unmarshal(body, &rawResponse); err != nil { + if rawErr != nil { + ReleaseBifrostResponse(result) return nil, &interfaces.BifrostError{ IsBifrostError: true, Error: interfaces.ErrorField{ - Message: "error parsing raw response", - Error: err, + Message: ErrOpenAIDecodeRaw.Error(), + Error: rawErr, }, } } - result := &interfaces.BifrostResponse{ - ID: response.ID, - Choices: response.Choices, - Object: response.Object, - Usage: response.Usage, - ServiceTier: response.ServiceTier, - SystemFingerprint: response.SystemFingerprint, - Created: response.Created, - Model: response.Model, - ExtraFields: interfaces.BifrostResponseExtraFields{ - Provider: interfaces.OpenAI, - RawResponse: rawResponse, - }, + // Populate result from response + result.ID = openAIResponse.ID + result.Choices = openAIResponse.Choices + result.Object = openAIResponse.Object + result.Usage = openAIResponse.Usage + result.ServiceTier = openAIResponse.ServiceTier + result.SystemFingerprint = openAIResponse.SystemFingerprint + result.Created = openAIResponse.Created + result.Model = openAIResponse.Model + result.ExtraFields = interfaces.BifrostResponseExtraFields{ + Provider: interfaces.OpenAI, + RawResponse: rawResponse, } return result, nil