From ec3dc33f84d13b78ca6ff8d9ce9852e37d6be590 Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Sat, 22 Mar 2025 16:35:55 +0530 Subject: [PATCH] fix: bifrost context fixes --- bifrost.go | 24 ++++++------------------ go.mod | 2 +- interfaces/plugin.go | 4 ++-- tests/anthropic_test.go | 7 +++++-- tests/openai_test.go | 7 +++++-- tests/plugin.go | 10 +++++----- 6 files changed, 24 insertions(+), 30 deletions(-) diff --git a/bifrost.go b/bifrost.go index 6e4a612fd..64cb0f36a 100644 --- a/bifrost.go +++ b/bifrost.go @@ -210,7 +210,7 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr return queue, nil } -func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) { queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err @@ -219,12 +219,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo responseChan := make(chan *interfaces.CompletionResult) errorChan := make(chan error) - // Create a context with timeout same as the provider/request config - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) + ctx, req, err = plugin.PreHook(ctx, req) if err != nil { return nil, err } @@ -240,7 +236,7 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo select { case result := <-responseChan: for _, plugin := range bifrost.plugins { - result, err = plugin.PostHook(&ctx, result) + result, err = plugin.PostHook(ctx, result) if err != nil { return nil, err @@ -250,12 +246,10 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo return result, nil case err := <-errorChan: return nil, err - case <-ctx.Done(): - return nil, ctx.Err() } } -func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) { queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err @@ -264,12 +258,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo responseChan := make(chan *interfaces.CompletionResult) errorChan := make(chan error) - // Create a context with timeout same as the provider/request config - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) + ctx, req, err = plugin.PreHook(ctx, req) if err != nil { return nil, err } @@ -285,7 +275,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo select { case result := <-responseChan: for _, plugin := range bifrost.plugins { - result, err = plugin.PostHook(&ctx, result) + result, err = plugin.PostHook(ctx, result) if err != nil { return nil, err @@ -295,8 +285,6 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo return result, nil case err := <-errorChan: return nil, err - case <-ctx.Done(): - return nil, ctx.Err() } } diff --git a/go.mod b/go.mod index 0145bb251..5f4f228b0 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,4 @@ go 1.21.1 require github.com/joho/godotenv v1.5.1 -require github.com/maximhq/maxim-go v0.1.1 // indirect +require github.com/maximhq/maxim-go v0.1.1 diff --git a/interfaces/plugin.go b/interfaces/plugin.go index 2a19dba35..bd322a87d 100644 --- a/interfaces/plugin.go +++ b/interfaces/plugin.go @@ -14,6 +14,6 @@ type BifrostRequest struct { } type Plugin interface { - PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) - PostHook(ctx *context.Context, result *CompletionResult) (*CompletionResult, error) + PreHook(ctx context.Context, req *BifrostRequest) (context.Context, *BifrostRequest, error) + PostHook(ctx context.Context, result *CompletionResult) (*CompletionResult, error) } diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go index d830ed285..9cad8fdb3 100644 --- a/tests/anthropic_test.go +++ b/tests/anthropic_test.go @@ -3,6 +3,7 @@ package tests import ( "bifrost" "bifrost/interfaces" + "context" "fmt" "testing" "time" @@ -17,6 +18,8 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { "Tell me about artificial intelligence.", } + ctx := context.Background() + go func() { params := interfaces.ModelParameters{ ExtraParams: map[string]interface{}{ @@ -31,7 +34,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { StringInput: &text, }, Params: ¶ms, - }) + }, ctx) if err != nil { fmt.Println("Error:", err) } else { @@ -61,7 +64,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { MessageInput: &messages, }, Params: ¶ms, - }) + }, ctx) if err != nil { fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err) diff --git a/tests/openai_test.go b/tests/openai_test.go index 5c8615de1..7ef30a29d 100644 --- a/tests/openai_test.go +++ b/tests/openai_test.go @@ -3,6 +3,7 @@ package tests import ( "bifrost" "bifrost/interfaces" + "context" "fmt" "testing" "time" @@ -12,6 +13,8 @@ import ( func setupOpenAIRequests(bifrost *bifrost.Bifrost) { text := "Hello world!" + ctx := context.Background() + // Text completion request go func() { result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ @@ -20,7 +23,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { StringInput: &text, }, Params: nil, - }) + }, ctx) if err != nil { fmt.Println("Error:", err) } else { @@ -52,7 +55,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { MessageInput: &messages, }, Params: nil, - }) + }, ctx) if err != nil { fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err) } else { diff --git a/tests/plugin.go b/tests/plugin.go index 2698df075..dace4ff9e 100644 --- a/tests/plugin.go +++ b/tests/plugin.go @@ -21,7 +21,7 @@ type Plugin struct { logger *logging.Logger } -func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) { +func (plugin *Plugin) PreHook(ctx context.Context, req *interfaces.BifrostRequest) (context.Context, *interfaces.BifrostRequest, error) { traceID := time.Now().Format("20060102_150405000") trace := plugin.logger.Trace(&logging.TraceConfig{ @@ -32,14 +32,14 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostReque trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req)) // Store traceID in context - *ctx = context.WithValue(*ctx, traceIDKey, traceID) + ctx = context.WithValue(ctx, traceIDKey, traceID) - return req, nil + return ctx, req, nil } -func (plugin *Plugin) PostHook(ctx *context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) { +func (plugin *Plugin) PostHook(ctx context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) { // Get traceID from context - traceID, ok := (*ctx).Value(traceIDKey).(string) + traceID, ok := ctx.Value(traceIDKey).(string) if !ok { return res, fmt.Errorf("traceID not found in context") }