From cfbee137d3576e713e6670eb19d94477e996ee5e Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Sat, 22 Mar 2025 00:54:26 +0530 Subject: [PATCH] feat: plugin context added --- bifrost.go | 36 +++++++++++++++++------------------- interfaces/plugin.go | 13 +++++++------ interfaces/provider.go | 1 - tests/plugin.go | 23 +++++++++++++++++------ 4 files changed, 41 insertions(+), 32 deletions(-) diff --git a/bifrost.go b/bifrost.go index 3e95a2c6b..6e4a612fd 100644 --- a/bifrost.go +++ b/bifrost.go @@ -3,6 +3,7 @@ package bifrost import ( "bifrost/interfaces" "bifrost/providers" + "context" "fmt" "math/rand" "os" @@ -218,13 +219,12 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo responseChan := make(chan *interfaces.CompletionResult) errorChan := make(chan error) - for _, plugin := range bifrost.plugins { - if req.PluginParams == nil { - req.PluginParams = make(map[string]interface{}) - } - - req, err = plugin.PreHook(req) + // Create a context with timeout same as the provider/request config + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for _, plugin := range bifrost.plugins { + req, err = plugin.PreHook(&ctx, req) if err != nil { return nil, err } @@ -239,10 +239,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo select { case result := <-responseChan: - result.PluginParams = req.PluginParams - for _, plugin := range bifrost.plugins { - result, err = plugin.PostHook(result) + result, err = plugin.PostHook(&ctx, result) if err != nil { return nil, err @@ -252,6 +250,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo return result, nil case err := <-errorChan: return nil, err + case <-ctx.Done(): + return nil, ctx.Err() } } @@ -264,13 +264,12 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo responseChan := make(chan *interfaces.CompletionResult) errorChan := make(chan error) - for _, plugin := range bifrost.plugins { - if req.PluginParams == nil { - req.PluginParams = make(map[string]interface{}) - } - - req, err = plugin.PreHook(req) + // Create a context with timeout same as the provider/request config + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for _, plugin := range bifrost.plugins { + req, err = plugin.PreHook(&ctx, req) if err != nil { return nil, err } @@ -283,13 +282,10 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo Type: ChatCompletionRequest, } - // Wait for response select { case result := <-responseChan: - result.PluginParams = req.PluginParams - for _, plugin := range bifrost.plugins { - result, err = plugin.PostHook(result) + result, err = plugin.PostHook(&ctx, result) if err != nil { return nil, err @@ -299,6 +295,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo return result, nil case err := <-errorChan: return nil, err + case <-ctx.Done(): + return nil, ctx.Err() } } diff --git a/interfaces/plugin.go b/interfaces/plugin.go index b983ca17c..2a19dba35 100644 --- a/interfaces/plugin.go +++ b/interfaces/plugin.go @@ -1,18 +1,19 @@ package interfaces +import "context" + type RequestInput struct { StringInput *string MessageInput *[]Message } type BifrostRequest struct { - Model string - Input RequestInput - Params *ModelParameters - PluginParams map[string]interface{} + Model string + Input RequestInput + Params *ModelParameters } type Plugin interface { - PreHook(req *BifrostRequest) (*BifrostRequest, error) - PostHook(result *CompletionResult) (*CompletionResult, error) + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) + PostHook(ctx *context.Context, result *CompletionResult) (*CompletionResult, error) } diff --git a/interfaces/provider.go b/interfaces/provider.go index 495e706ae..36d7ac094 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -125,7 +125,6 @@ type CompletionResult struct { Model string `json:"model"` Created string `json:"created"` Params *interface{} `json:"modelParams"` - PluginParams map[string]interface{} `json:"-"` Trace *struct { Input interface{} `json:"input"` Output interface{} `json:"output"` diff --git a/tests/plugin.go b/tests/plugin.go index 8ceaad20e..2698df075 100644 --- a/tests/plugin.go +++ b/tests/plugin.go @@ -2,6 +2,7 @@ package tests import ( "bifrost/interfaces" + "context" "fmt" "time" @@ -9,11 +10,18 @@ import ( "github.com/maximhq/maxim-go/logging" ) +// Define a custom type for context key to avoid collisions +type contextKey string + +const ( + traceIDKey contextKey = "traceID" +) + type Plugin struct { logger *logging.Logger } -func (plugin *Plugin) PreHook(req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) { +func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) { traceID := time.Now().Format("20060102_150405000") trace := plugin.logger.Trace(&logging.TraceConfig{ @@ -23,15 +31,18 @@ func (plugin *Plugin) PreHook(req *interfaces.BifrostRequest) (*interfaces.Bifro trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req)) - req.PluginParams["traceID"] = traceID + // Store traceID in context + *ctx = context.WithValue(*ctx, traceIDKey, traceID) return req, nil } -func (plugin *Plugin) PostHook(res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) { - fmt.Println(res.PluginParams) - - traceID := res.PluginParams["traceID"].(string) +func (plugin *Plugin) PostHook(ctx *context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) { + // Get traceID from context + traceID, ok := (*ctx).Value(traceIDKey).(string) + if !ok { + return res, fmt.Errorf("traceID not found in context") + } plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res)) return res, nil