From 6458629ed2d330db78364ad87da13cf9a187942b Mon Sep 17 00:00:00 2001
From: Pratham-Mishra04
Date: Wed, 9 Apr 2025 09:54:01 +0530
Subject: [PATCH] dev: timing tests

---
 bifrost.go           | 219 +++++++++++++++++++++++++++++++++++++++++--
 providers/openai.go  | 149 +++++++++++++++++++++++++----
 tests/account.go     |   4 +-
 tests/openai_test.go |  91 +++++++++++++++++-
 4 files changed, 435 insertions(+), 28 deletions(-)

diff --git a/bifrost.go b/bifrost.go
index 2d3cd61d5..94421f063 100644
--- a/bifrost.go
+++ b/bifrost.go
@@ -2,12 +2,15 @@ package bifrost

 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"math/rand"
 	"os"
 	"os/signal"
+	"runtime/debug"
 	"slices"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
@@ -15,6 +18,18 @@ import (
 	"github.com/maximhq/bifrost/providers"
 )

+// RequestMetrics tracks rolling-average timings and request counts across all requests.
+type RequestMetrics struct {
+	TotalTime        time.Duration `json:"total_time"`
+	QueueWaitTime    time.Duration `json:"queue_wait_time"`
+	KeySelectionTime time.Duration `json:"key_selection_time"`
+	ProviderTime     time.Duration `json:"provider_time"`
+	PluginPreTime    time.Duration `json:"plugin_pre_time"`
+	PluginPostTime   time.Duration `json:"plugin_post_time"`
+	RequestCount     int64         `json:"request_count"`
+	ErrorCount       int64         `json:"error_count"`
+}
+
 type RequestType string

 const (
@@ -24,9 +39,10 @@ const (

 type ChannelMessage struct {
 	interfaces.BifrostRequest
-	Response chan *interfaces.BifrostResponse
-	Err      chan interfaces.BifrostError
-	Type     RequestType
+	Response  chan *interfaces.BifrostResponse
+	Err       chan interfaces.BifrostError
+	Type      RequestType
+	Timestamp time.Time
 }

 // Bifrost manages providers and maintains infinite open channels
@@ -40,6 +56,19 @@ type Bifrost struct {
 	responseChannelPool sync.Pool // Pool for response channels
 	errorChannelPool    sync.Pool // Pool for error channels
 	logger              interfaces.Logger
+	metrics             RequestMetrics
+	metricsMutex        sync.RWMutex
+
+	// Pool usage counters
+	channelMessageGets       atomic.Int64
+	channelMessagePuts       atomic.Int64
+	channelMessageCreations  atomic.Int64
+	responseChannelGets      atomic.Int64
+	responseChannelPuts      atomic.Int64
+	responseChannelCreations atomic.Int64
+	errorChannelGets         atomic.Int64
+	errorChannelPuts         atomic.Int64
+	errorChannelCreations    atomic.Int64
 }

 func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) {
@@ -95,6 +124,8 @@ func Init(config interfaces.BifrostConfig) (*Bifrost, error) {
 		return nil, fmt.Errorf("account is required to initialize Bifrost")
 	}

+	// dev: disable the garbage collector so GC pauses don't skew timing runs
+	debug.SetGCPercent(-1)
+
 	bifrost := &Bifrost{
 		account:       config.Account,
 		plugins:       config.Plugins,
@@ -105,22 +136,25 @@ func Init(config interfaces.BifrostConfig) (*Bifrost, error) {
 	// Initialize object pools
 	bifrost.channelMessagePool = sync.Pool{
 		New: func() interface{} {
+			bifrost.channelMessageCreations.Add(1)
 			return &ChannelMessage{}
 		},
 	}
 	bifrost.responseChannelPool = sync.Pool{
 		New: func() interface{} {
+			bifrost.responseChannelCreations.Add(1)
 			return make(chan *interfaces.BifrostResponse, 1)
 		},
 	}
 	bifrost.errorChannelPool = sync.Pool{
 		New: func() interface{} {
+			bifrost.errorChannelCreations.Add(1)
 			return make(chan interfaces.BifrostError, 1)
 		},
 	}

 	// Prewarm pools with multiple objects
-	for range config.InitialPoolSize {
+	// dev: hardcoded prewarm size for timing tests (overrides config.InitialPoolSize)
+	for range 2500 {
 		// Create and put new objects directly into pools
 		bifrost.channelMessagePool.Put(&ChannelMessage{})
 		bifrost.responseChannelPool.Put(make(chan *interfaces.BifrostResponse, 1))
@@ -156,7 +190,10 @@ func Init(config interfaces.BifrostConfig) (*Bifrost, error) {

 // getChannelMessage gets a ChannelMessage from the pool
 func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType RequestType) *ChannelMessage {
 	// Get channels from pool
+	bifrost.responseChannelGets.Add(1)
 	responseChan := bifrost.responseChannelPool.Get().(chan *interfaces.BifrostResponse)
+
+	bifrost.errorChannelGets.Add(1)
 	errorChan := bifrost.errorChannelPool.Get().(chan interfaces.BifrostError)

 	// Clear any previous values to avoid leaking between requests
@@ -170,11 +207,13 @@ func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType RequestType) *ChannelMessage {
 	}

 	// Get message from pool and configure it
+	bifrost.channelMessageGets.Add(1)
 	msg := bifrost.channelMessagePool.Get().(*ChannelMessage)
 	msg.BifrostRequest = req
 	msg.Response = responseChan
 	msg.Err = errorChan
 	msg.Type = reqType
+	msg.Timestamp = time.Now()

 	return msg
 }
@@ -182,12 +221,16 @@ func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType RequestType) *ChannelMessage {
 // releaseChannelMessage returns a ChannelMessage and its channels to the pool
 func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
 	// Put channels back in pools
+	bifrost.responseChannelPuts.Add(1)
 	bifrost.responseChannelPool.Put(msg.Response)
+
+	bifrost.errorChannelPuts.Add(1)
 	bifrost.errorChannelPool.Put(msg.Err)

 	// Clear references and return to pool
 	msg.Response = nil
 	msg.Err = nil
+	bifrost.channelMessagePuts.Add(1)
 	bifrost.channelMessagePool.Put(msg)
 }
@@ -254,15 +297,39 @@ func (bifrost *Bifrost) calculateBackoff(attempt int, config *interfaces.ProviderConfig) time.Duration {
 	return time.Duration(jitter)
 }

+// recordError folds a failed request into the rolling metrics, approximating
+// the total as the sum of the measured phases.
+func (bifrost *Bifrost) recordError(queueWaitTime, keySelectTime, providerTime, pluginPreTime, pluginPostTime time.Duration) {
+	totalTime := queueWaitTime + keySelectTime + providerTime + pluginPreTime + pluginPostTime
+	bifrost.recordMetrics(queueWaitTime, keySelectTime, providerTime, pluginPreTime, pluginPostTime, totalTime, false)
+}
+
 func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
 	defer bifrost.waitGroups[provider.GetProviderKey()].Done()

 	for req := range queue {
+		startTime := time.Now()
+		queueWaitTime := startTime.Sub(req.Timestamp)
+
 		var result *interfaces.BifrostResponse
 		var bifrostError *interfaces.BifrostError

+		// Track key selection time
+		keySelectStart := time.Now()
 		key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
+		keySelectTime := time.Since(keySelectStart)
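+		// Worked example of the rolling average these samples feed (see
+		// recordMetrics): after key-selection samples of 10ms and 20ms the
+		// stored average is (10ms*1 + 20ms)/2 = 15ms.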
+
 		if err != nil {
+			bifrost.recordError(queueWaitTime, keySelectTime, 0, 0, 0)
 			req.Err <- interfaces.BifrostError{
 				IsBifrostError: false,
 				Error: interfaces.ErrorField{
@@ -288,6 +355,9 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
 		// Track attempts
 		var attempts int

+		// Track provider processing time
+		providerStart := time.Now()
+
 		// Execute request with retries
 		for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ {
 			if attempts > 0 {
@@ -339,6 +409,11 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
 			}
 		}

+		providerTime := time.Since(providerStart)
+
+		totalTime := time.Since(startTime)
+		bifrost.recordMetrics(queueWaitTime, keySelectTime, providerTime, 0, 0, totalTime, bifrostError == nil)
+
 		if bifrostError != nil {
 			// Add retry information to error
 			if attempts > 0 {
@@ -355,6 +430,29 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
 	bifrost.logger.Debug(fmt.Sprintf("Worker for provider %s exiting...", provider.GetProviderKey()))
 }

+// recordMetrics folds one request's timings into the rolling averages using
+// avg = (avg*(n-1) + sample) / n. All access is serialized by metricsMutex,
+// so plain increments are sufficient. Note that both the worker loop and the
+// public request methods record samples, each passing zero for the phases
+// they do not measure, so treat the per-phase averages as dev-mode estimates.
+func (bifrost *Bifrost) recordMetrics(queueWaitTime, keySelectTime, providerTime, pluginPreTime, pluginPostTime, totalTime time.Duration, success bool) {
+	bifrost.metricsMutex.Lock()
+	defer bifrost.metricsMutex.Unlock()
+
+	bifrost.metrics.RequestCount++
+	if !success {
+		bifrost.metrics.ErrorCount++
+	}
+
+	n := time.Duration(bifrost.metrics.RequestCount)
+	bifrost.metrics.QueueWaitTime = (bifrost.metrics.QueueWaitTime*(n-1) + queueWaitTime) / n
+	bifrost.metrics.KeySelectionTime = (bifrost.metrics.KeySelectionTime*(n-1) + keySelectTime) / n
+	bifrost.metrics.ProviderTime = (bifrost.metrics.ProviderTime*(n-1) + providerTime) / n
+	bifrost.metrics.PluginPreTime = (bifrost.metrics.PluginPreTime*(n-1) + pluginPreTime) / n
+	bifrost.metrics.PluginPostTime = (bifrost.metrics.PluginPostTime*(n-1) + pluginPostTime) / n
+	bifrost.metrics.TotalTime = (bifrost.metrics.TotalTime*(n-1) + totalTime) / n
+}
+
+// GetMetrics returns a snapshot of the rolling request metrics.
+func (bifrost *Bifrost) GetMetrics() RequestMetrics {
+	bifrost.metricsMutex.RLock()
+	defer bifrost.metricsMutex.RUnlock()
+	return bifrost.metrics
+}
+
 func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) {
 	for _, provider := range bifrost.providers {
 		if provider.GetProviderKey() == key {
@@ -386,6 +484,8 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan ChannelMessage, error) {
 }

 func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
+	startTime := time.Now()
+
 	if req == nil {
 		return nil, &interfaces.BifrostError{
 			IsBifrostError: false,
@@ -395,7 +495,11 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 		}
 	}

+	// Track queue acquisition time
+	queueStart := time.Now()
 	queue, err := bifrost.GetProviderQueue(providerKey)
+	queueTime := time.Since(queueStart)
+
 	if err != nil {
 		return nil, &interfaces.BifrostError{
 			IsBifrostError: false,
@@ -405,6 +509,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 		}
 	}

+	// Track plugin pre-hook time
+	pluginPreStart := time.Now()
 	for _, plugin := range bifrost.plugins {
 		req, err = plugin.PreHook(&ctx, req)
 		if err != nil {
@@ -416,6 +522,7 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			}
 		}
 	}
+	pluginPreTime := time.Since(pluginPreStart)

 	if req == nil {
 		return nil, &interfaces.BifrostError{
@@ -434,7 +541,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 	var result *interfaces.BifrostResponse
 	select {
 	case result = <-msg.Response:
-		// Run plugins in reverse order
+		// Track plugin post-hook time
+		pluginPostStart := time.Now()
 		for i := len(bifrost.plugins) - 1; i >= 0; i-- {
 			result, err = bifrost.plugins[i].PostHook(&ctx, result)
 			if err != nil {
@@ -447,17 +555,31 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			}
 		}
 	}
+		pluginPostTime := time.Since(pluginPostStart)

+		totalTime := time.Since(startTime)
+		// Record queueTime on success as well, mirroring the error path below
+		bifrost.recordMetrics(queueTime, 0, 0, pluginPreTime, pluginPostTime, totalTime, true)
+
 	case err := <-msg.Err:
 		bifrost.releaseChannelMessage(msg)
+		totalTime := time.Since(startTime)
+		bifrost.recordMetrics(queueTime, 0, 0, pluginPreTime, 0, totalTime, false)
 		return nil, &err
 	}

+	// Add bifrost metrics to the response
+	if rawResponse, ok := result.ExtraFields.RawResponse.(map[string]interface{}); ok {
+		rawResponse["bifrost_timings"] = bifrost.GetMetrics()
+		result.ExtraFields.RawResponse = rawResponse
+	}
+
 	// Return message to pool
 	bifrost.releaseChannelMessage(msg)
 	return result, nil
 }

 func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
+	startTime := time.Now()
+
 	if req == nil {
 		return nil, &interfaces.BifrostError{
 			IsBifrostError: false,
@@ -467,7 +589,11 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 		}
 	}

+	// Track queue acquisition time
+	queueStart := time.Now()
 	queue, err := bifrost.GetProviderQueue(providerKey)
+	queueTime := time.Since(queueStart)
+
 	if err != nil {
 		return nil, &interfaces.BifrostError{
 			IsBifrostError: false,
@@ -477,6 +603,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 		}
 	}

+	// Track plugin pre-hook time
+	pluginPreStart := time.Now()
 	for _, plugin := range bifrost.plugins {
 		req, err = plugin.PreHook(&ctx, req)
 		if err != nil {
@@ -488,6 +616,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			}
 		}
 	}
+	pluginPreTime := time.Since(pluginPreStart)

 	if req == nil {
 		return nil, &interfaces.BifrostError{
@@ -506,7 +635,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 	var result *interfaces.BifrostResponse
 	select {
 	case result = <-msg.Response:
-		// Run plugins in reverse order
+		// Track plugin post-hook time
+		pluginPostStart := time.Now()
 		for i := len(bifrost.plugins) - 1; i >= 0; i-- {
 			result, err = bifrost.plugins[i].PostHook(&ctx, result)
 			if err != nil {
@@ -519,21 +649,98 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			}
 		}
+		pluginPostTime := time.Since(pluginPostStart)

+		totalTime := time.Since(startTime)
+		// Record queueTime on success as well, mirroring the error path below
+		bifrost.recordMetrics(queueTime, 0, 0, pluginPreTime, pluginPostTime, totalTime, true)

 	case err := <-msg.Err:
 		bifrost.releaseChannelMessage(msg)
+		totalTime := time.Since(startTime)
+		bifrost.recordMetrics(queueTime, 0, 0, pluginPreTime, 0, totalTime, false)
 		return nil, &err
 	}

+	// Add bifrost metrics to the response
+	if rawResponse, ok := result.ExtraFields.RawResponse.(map[string]interface{}); ok {
+		rawResponse["bifrost_timings"] = bifrost.GetMetrics()
+		result.ExtraFields.RawResponse = rawResponse
+	}
+
 	// Return message to pool
 	bifrost.releaseChannelMessage(msg)
 	return result, nil
 }

+// GetAllStats returns all statistics including request metrics and pool usage
+func (bifrost *Bifrost) GetAllStats() map[string]interface{} {
+	stats := make(map[string]interface{})
+
+	// Add request metrics
+	metrics := bifrost.GetMetrics()
+
+	// Guard against division by zero before any requests have been recorded
+	errorRate := "0.00%"
+	if metrics.RequestCount > 0 {
+		errorRate = fmt.Sprintf("%.2f%%", float64(metrics.ErrorCount)/float64(metrics.RequestCount)*100)
+	}
+
+	stats["request_metrics"] = map[string]interface{}{
+		"total_time":         metrics.TotalTime.String(),
+		"queue_wait_time":    metrics.QueueWaitTime.String(),
+		"key_selection_time": metrics.KeySelectionTime.String(),
+		"provider_time":      metrics.ProviderTime.String(),
+		"plugin_pre_time":    metrics.PluginPreTime.String(),
+		"plugin_post_time":   metrics.PluginPostTime.String(),
+		"request_count":      metrics.RequestCount,
+		"error_count":        metrics.ErrorCount,
+		"error_rate":         errorRate,
+	}
+
+	// Add pool usage statistics
+	stats["pool_stats"] = bifrost.GetPoolStats()
+
+	return stats
+}
+
+// GetPoolStats returns statistics about object pool usage
+func (bifrost *Bifrost) GetPoolStats() map[string]interface{} {
+	stats := make(map[string]interface{})
+
+	// Add channel message pool stats
+	stats["channel_message_pool"] = map[string]int64{
+		"gets":      bifrost.channelMessageGets.Load(),
+		"puts":      bifrost.channelMessagePuts.Load(),
+		"creations": bifrost.channelMessageCreations.Load(),
+	}
+
+	// Add response channel pool stats
+	stats["response_channel_pool"] = map[string]int64{
+		"gets":      bifrost.responseChannelGets.Load(),
+		"puts":      bifrost.responseChannelPuts.Load(),
+		"creations": bifrost.responseChannelCreations.Load(),
+	}
+
+	// Add error channel pool stats
+	stats["error_channel_pool"] = map[string]int64{
+		"gets":      bifrost.errorChannelGets.Load(),
+		"puts":      bifrost.errorChannelPuts.Load(),
+		"creations": bifrost.errorChannelCreations.Load(),
+	}
+
+	// Add provider-specific pool stats
+	providerStats := providers.GetPoolStats()
+	for k, v := range providerStats {
+		stats[k] = v
+	}
+
+	return stats
+}
+
 // Shutdown gracefully stops all workers when triggered
 func (bifrost *Bifrost) Shutdown() {
 	bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...")

+	stats := bifrost.GetAllStats()
+	statsJSON, err := json.MarshalIndent(stats, "", "  ")
+	if err != nil {
+		bifrost.logger.Info(fmt.Sprintf("[BIFROST] Failed to marshal stats: %v", err))
+	} else {
+		bifrost.logger.Info(fmt.Sprintf("[BIFROST] Statistics:\n%s", statsJSON))
+	}
+
 	// Close all provider queues to signal workers to stop
 	for _, queue := range bifrost.requestQueues {
 		close(queue)
diff --git a/providers/openai.go b/providers/openai.go
index 2d17377f0..8082bedd7 100644
--- a/providers/openai.go
+++ b/providers/openai.go
@@ -3,10 +3,13 @@ package providers

 import (
 	"encoding/json"
 	"fmt"
+	"math/rand/v2"
 	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/maximhq/bifrost/interfaces"
+	"github.com/maximhq/maxim-go"
 	"github.com/valyala/fasthttp"
 )
@@ -20,9 +23,33 @@ var (
 	ErrOpenAIDecompress = fmt.Errorf("error decompressing OpenAI response")
 )

+// Counters for pool usage
+var (
+	openAIPoolGets      atomic.Int64
+	openAIPoolPuts      atomic.Int64
+	openAIPoolCreations atomic.Int64
+
+	bifrostPoolGets      atomic.Int64
+	bifrostPoolPuts      atomic.Int64
+	bifrostPoolCreations atomic.Int64
+)
+
+// GetPoolStats returns the current pool usage statistics
+func GetPoolStats() map[string]int64 {
+	return map[string]int64{
+		"openai_pool_gets":       openAIPoolGets.Load(),
+		"openai_pool_puts":       openAIPoolPuts.Load(),
+		"openai_pool_creations":  openAIPoolCreations.Load(),
+		"bifrost_pool_gets":      bifrostPoolGets.Load(),
+		"bifrost_pool_puts":      bifrostPoolPuts.Load(),
+		"bifrost_pool_creations": bifrostPoolCreations.Load(),
+	}
+}
+
 // OpenAIResponsePool provides a pool for OpenAI response objects
 var openAIResponsePool = sync.Pool{
 	New: func() interface{} {
+		openAIPoolCreations.Add(1)
 		return &OpenAIResponse{}
 	},
 }
@@ -30,12 +57,14 @@ var openAIResponsePool = sync.Pool{
 // BifrostResponsePool provides a pool for Bifrost response objects
 var bifrostResponsePool = sync.Pool{
 	New: func() interface{} {
+		bifrostPoolCreations.Add(1)
 		return &interfaces.BifrostResponse{}
 	},
 }

 // AcquireOpenAIResponse gets an OpenAI response from the pool
 func AcquireOpenAIResponse() *OpenAIResponse {
+	openAIPoolGets.Add(1)
 	resp := openAIResponsePool.Get().(*OpenAIResponse)
 	*resp = OpenAIResponse{} // Reset the struct
 	return resp
@@ -44,12 +73,14 @@ func AcquireOpenAIResponse() *OpenAIResponse {
 // ReleaseOpenAIResponse returns an OpenAI response to the pool
 func ReleaseOpenAIResponse(resp *OpenAIResponse) {
 	if resp != nil {
+		openAIPoolPuts.Add(1)
 		openAIResponsePool.Put(resp)
 	}
 }

 // AcquireBifrostResponse gets a Bifrost response from the pool
 func AcquireBifrostResponse() *interfaces.BifrostResponse {
+	bifrostPoolGets.Add(1)
 	resp := bifrostResponsePool.Get().(*interfaces.BifrostResponse)
 	*resp = interfaces.BifrostResponse{} // Reset the struct
 	return resp
@@ -58,6 +89,7 @@ func AcquireBifrostResponse() *interfaces.BifrostResponse {
 // ReleaseBifrostResponse returns a Bifrost response to the pool
 func ReleaseBifrostResponse(resp *interfaces.BifrostResponse) {
 	if resp != nil {
+		bifrostPoolPuts.Add(1)
 		bifrostResponsePool.Put(resp)
 	}
 }
@@ -87,8 +119,9 @@ type OpenAIError struct {

 // OpenAIProvider implements the Provider interface for OpenAI
 type OpenAIProvider struct {
-	logger interfaces.Logger
-	client *fasthttp.Client
+	// MockResponse, when true, makes ChatCompletion return a canned response
+	// with simulated latency instead of calling the OpenAI API (load testing).
+	MockResponse bool
+	logger       interfaces.Logger
+	client       *fasthttp.Client
 }

 // NewOpenAIProvider creates a new OpenAI provider instance
@@ -100,8 +133,8 @@ func NewOpenAIProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *OpenAIProvider {
 		MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize,
 	}

+	// Prewarm pools directly without affecting counters
 	for range config.ConcurrencyAndBufferSize.Concurrency {
-		// Create and put new objects directly into pools
 		openAIResponsePool.Put(&OpenAIResponse{})
 		bifrostResponsePool.Put(&interfaces.BifrostResponse{})
 	}
@@ -110,8 +143,9 @@ func NewOpenAIProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *OpenAIProvider {
 	client = configureProxy(client, config.ProxyConfig, logger)

 	return &OpenAIProvider{
-		logger: logger,
-		client: client,
+		MockResponse: false,
+		logger:       logger,
+		client:       client,
 	}
 }
@@ -130,6 +164,10 @@ func NewOpenAIProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *OpenAIProvider {
 }

 func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
+	timings := make(map[string]time.Duration)
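+
+	// Each phase below is timed into `timings`; the map is attached to
+	// ExtraFields.RawResponse["timings"] at the end of this function.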
+
+	// Track message formatting time
+	formatStart := time.Now()
 	// Format messages for OpenAI API
 	var formattedMessages []map[string]interface{}
 	for _, msg := range messages {
@@ -168,14 +206,23 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			})
 		}
 	}
+	timings["message_formatting"] = time.Since(formatStart)

+	// Track params preparation time
+	paramsStart := time.Now()
 	preparedParams := PrepareParams(params)
+	timings["params_preparation"] = time.Since(paramsStart)

+	// Track request body preparation time
+	bodyStart := time.Now()
 	requestBody := MergeConfig(map[string]interface{}{
 		"model":    model,
 		"messages": formattedMessages,
 	}, preparedParams)
+	timings["request_body_preparation"] = time.Since(bodyStart)

+	// Track JSON marshaling time
+	marshalStart := time.Now()
 	jsonBody, err := json.Marshal(requestBody)
 	if err != nil {
 		return nil, &interfaces.BifrostError{
@@ -186,8 +233,10 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			},
 		}
 	}
+	timings["json_marshaling"] = time.Since(marshalStart)

-	// Create request
+	// Track request setup time
+	setupStart := time.Now()
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
 	defer fasthttp.ReleaseRequest(req)
@@ -198,19 +247,44 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 	req.Header.SetContentType("application/json")
 	req.Header.Set("Authorization", "Bearer "+key)
 	req.SetBody(jsonBody)
+	timings["request_setup"] = time.Since(setupStart)
+
+	// Track HTTP request time
+	httpStart := time.Now()
+
+	shouldMakeRealCall := true
+	if provider.MockResponse {
+		// Try mock response first
+		if mockResponse := mockOpenAIChatCompletionResponse(req, model); mockResponse != nil {
+			// Copy the mock response body to the real response
+			resp.SetBody(mockResponse)
+			// Simulate network delay: 1500ms scaled by [0.6, 1.4), i.e. ~900ms-2100ms
+			jitter := time.Duration(float64(1500*time.Millisecond) * (0.6 + 0.8*rand.Float64()))
+			time.Sleep(jitter)
+			shouldMakeRealCall = false
+		} else {
+			// Fall back to the real API call if mock generation fails
+			provider.logger.Debug("Mock response generation failed, falling back to real API call")
+		}
+	}

-	// Make request
-	if err := provider.client.Do(req, resp); err != nil {
-		return nil, &interfaces.BifrostError{
-			IsBifrostError: false,
-			Error: interfaces.ErrorField{
-				Message: ErrOpenAIRequest.Error(),
-				Error:   err,
-			},
+	if shouldMakeRealCall {
+		// Make the real API call
+		if err := provider.client.Do(req, resp); err != nil {
+			return nil, &interfaces.BifrostError{
+				IsBifrostError: false,
+				Error: interfaces.ErrorField{
+					Message: ErrOpenAIRequest.Error(),
+					Error:   err,
+				},
+			}
 		}
 	}

-	// Handle error response
+	timings["http_request"] = time.Since(httpStart)
+
+	// Track error handling time
+	errorStart := time.Now()
 	if resp.StatusCode() != fasthttp.StatusOK {
 		var errorResp OpenAIError

 		if err := json.Unmarshal(resp.Body(), &errorResp); err != nil {
@@ -238,9 +312,12 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 			},
 		}
 	}
+	timings["error_handling"] = time.Since(errorStart)

 	responseBody := resp.Body()

+	// Track response parsing time
+	parseStart := time.Now()
 	// Pre-allocate response structs from pools
 	openAIResponse := AcquireOpenAIResponse()
 	defer ReleaseOpenAIResponse(openAIResponse)
@@ -284,6 +361,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 		}
 	}
+	timings["response_parsing"] = time.Since(parseStart)

 	// Populate result from response
 	result.ID = openAIResponse.ID
@@ -295,11 +373,48 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
 	result.Created = openAIResponse.Created
 	result.Model = openAIResponse.Model
 	result.ExtraFields = interfaces.BifrostResponseExtraFields{
-		Provider:    interfaces.OpenAI,
-		RawResponse: rawResponse,
+		Provider: interfaces.OpenAI,
+		RawResponse: map[string]interface{}{
+			"response": rawResponse,
+			"timings":  timings,
+		},
 	}

-	ReleaseBifrostResponse(result)
-
+	// Note: result must NOT be released here; it is handed to the caller, and
+	// putting it back in the pool would let a concurrent request reset it.
 	return result, nil
 }
+
+// mockOpenAIChatCompletionResponse builds a canned response that mimics the
+// OpenAI chat completion format, for load testing without real API calls.
+// The req parameter is currently unused but kept so mocks can inspect the
+// outgoing request later.
+func mockOpenAIChatCompletionResponse(req *fasthttp.Request, model string) []byte {
+	mockResp := &OpenAIResponse{
+		ID:      fmt.Sprintf("mock-%s-%d", model, time.Now().Unix()),
+		Object:  "chat.completion",
+		Model:   model,
+		Created: int(time.Now().Unix()),
+		Choices: []interfaces.BifrostResponseChoice{
+			{
+				Index: 0,
+				Message: interfaces.BifrostResponseChoiceMessage{
+					Role:    interfaces.RoleAssistant,
+					Content: maxim.StrPtr("This is a mock response from the Bifrost API gateway. The actual API was not called."),
+				},
+				FinishReason: maxim.StrPtr("stop"),
+			},
+		},
+		Usage: interfaces.LLMUsage{
+			PromptTokens:     100,
+			CompletionTokens: 50,
+			TotalTokens:      150,
+		},
+	}
+
+	// Convert to JSON; a nil return tells the caller to fall back to a real call
+	mockJSON, err := json.Marshal(mockResp)
+	if err != nil {
+		return nil
+	}
+
+	return mockJSON
+}
diff --git a/tests/account.go b/tests/account.go
index f4df02a2a..dab8d25b5 100644
--- a/tests/account.go
+++ b/tests/account.go
@@ -70,8 +70,8 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.SupportedModelProvider) (*interfaces.ProviderConfig, error) {
 				RetryBackoffMax: 2 * time.Second,
 			},
 			ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{
-				Concurrency: 3,
-				BufferSize:  10,
+				Concurrency: 1000,
+				BufferSize:  1000,
 			},
 		}, nil
 	case interfaces.Anthropic:
diff --git a/tests/openai_test.go b/tests/openai_test.go
index 655d2d98a..8c9538dfd 100644
--- a/tests/openai_test.go
+++ b/tests/openai_test.go
@@ -160,7 +160,7 @@ func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
 	for i, message := range openAIMessages {
 		delay := time.Duration(100*(i+1)) * time.Millisecond
 		go func(msg string, delay time.Duration, index int) {
-			time.Sleep(delay)
+			// time.Sleep(delay) // dev: staggering disabled for timing runs
 			messages := []interfaces.Message{
 				{
 					Role:    interfaces.RoleUser,
@@ -177,8 +177,12 @@ func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
 			if err != nil {
 				fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err.Error.Message)
 			} else {
-				toolCall := *result.Choices[0].Message.ToolCalls
-				fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments)
+				toolCall := result.Choices[0].Message.ToolCalls
+				if toolCall != nil && len(*toolCall) > 0 {
+					fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, (*toolCall)[0].Function.Arguments)
+				} else {
+					fmt.Printf("🐒 Tool Call Result %d: No tool call found\n", index+1)
+				}
 			}
 		}(message, delay, i)
 	}
@@ -195,3 +199,84 @@ func TestOpenAI(t *testing.T) {

 	bifrost.Cleanup()
 }
+
+// TestOpenAILoadTest fires totalRequests concurrent requests (currently 25)
+// at the gateway with round-robin message selection.
+func TestOpenAILoadTest(t *testing.T) {
+	bifrost, err := getBifrost()
+	if err != nil {
+		t.Fatalf("Error initializing bifrost: %v", err)
+	}
+
+	// Sample messages for round-robin distribution
+	openAIMessages := []string{
How are you today?", + "Tell me a joke!", + "What's your favorite programming language?", + "Explain quantum computing in simple terms.", + "What are the best practices for writing clean code?", + } + + // Channel to track completion of all requests + done := make(chan bool) + ctx := context.Background() + totalRequests := 25 + completedRequests := 0 + droppedRequests := 0 + + // Start time tracking + startTime := time.Now() + + // Launch 100,000 requests + for i := 0; i < totalRequests; i++ { + // Round-robin message selection + message := openAIMessages[i%len(openAIMessages)] + + go func(msg string, index int) { + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + + _, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: nil, + }, ctx) + + if err != nil { + fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err.Error.Message) + droppedRequests++ + } else { + t.Logf("Request %d completed successfully", index+1) + } + + // Track completion + completedRequests++ + if completedRequests == totalRequests { + done <- true + } + + if completedRequests%10 == 0 { + fmt.Printf("Completed %d requests, dropped %d requests\n", completedRequests, droppedRequests) + } + + }(message, i) + } + + // Wait for all requests to complete or timeout after 5 minutes + select { + case <-done: + elapsed := time.Since(startTime) + t.Logf("All %d requests completed in %v", totalRequests, elapsed) + t.Logf("Average request time: %v", elapsed/time.Duration(totalRequests)) + case <-time.After(5 * time.Minute): + t.Errorf("Test timed out after 5 minutes. Completed %d/%d requests", completedRequests, totalRequests) + } + + bifrost.Cleanup() +}