Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bifrost
import (
"bifrost/interfaces"
"bifrost/providers"
"context"
"fmt"
"math/rand"
"os"
Expand Down Expand Up @@ -218,13 +219,12 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
responseChan := make(chan *interfaces.CompletionResult)
errorChan := make(chan error)

for _, plugin := range bifrost.plugins {
if req.PluginParams == nil {
req.PluginParams = make(map[string]interface{})
}

req, err = plugin.PreHook(req)
// Create a context with timeout same as the provider/request config
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will not create our own context, we shall ask users to pass in the context from the chat/text completion functions. keep it as the first param there. they would pass in their context to the completion functions that we shall use with in the plugin system

defer cancel()

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -239,10 +239,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo

select {
case result := <-responseChan:
result.PluginParams = req.PluginParams

for _, plugin := range bifrost.plugins {
result, err = plugin.PostHook(result)
result, err = plugin.PostHook(&ctx, result)

if err != nil {
return nil, err
Expand All @@ -252,6 +250,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
return result, nil
case err := <-errorChan:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}

Expand All @@ -264,13 +264,12 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
responseChan := make(chan *interfaces.CompletionResult)
errorChan := make(chan error)

for _, plugin := range bifrost.plugins {
if req.PluginParams == nil {
req.PluginParams = make(map[string]interface{})
}

req, err = plugin.PreHook(req)
// Create a context with timeout same as the provider/request config
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -283,13 +282,10 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
Type: ChatCompletionRequest,
}

// Wait for response
select {
case result := <-responseChan:
result.PluginParams = req.PluginParams

for _, plugin := range bifrost.plugins {
result, err = plugin.PostHook(result)
result, err = plugin.PostHook(&ctx, result)

if err != nil {
return nil, err
Expand All @@ -299,6 +295,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
return result, nil
case err := <-errorChan:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}

Expand Down
13 changes: 7 additions & 6 deletions interfaces/plugin.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
package interfaces

import "context"

type RequestInput struct {
StringInput *string
MessageInput *[]Message
}

type BifrostRequest struct {
Model string
Input RequestInput
Params *ModelParameters
PluginParams map[string]interface{}
Model string
Input RequestInput
Params *ModelParameters
}

type Plugin interface {
PreHook(req *BifrostRequest) (*BifrostRequest, error)
PostHook(result *CompletionResult) (*CompletionResult, error)
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error)
Copy link
Collaborator

@danpiths danpiths Mar 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us pass in the context object and not a pointer to the context for both the pre and post hook

PostHook(ctx *context.Context, result *CompletionResult) (*CompletionResult, error)
}
1 change: 0 additions & 1 deletion interfaces/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ type CompletionResult struct {
Model string `json:"model"`
Created string `json:"created"`
Params *interface{} `json:"modelParams"`
PluginParams map[string]interface{} `json:"-"`
Trace *struct {
Input interface{} `json:"input"`
Output interface{} `json:"output"`
Expand Down
23 changes: 17 additions & 6 deletions tests/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ package tests

import (
"bifrost/interfaces"
"context"
"fmt"
"time"

"github.com/maximhq/maxim-go"
"github.com/maximhq/maxim-go/logging"
)

// Define a custom type for context key to avoid collisions
type contextKey string

const (
traceIDKey contextKey = "traceID"
)

type Plugin struct {
logger *logging.Logger
}

func (plugin *Plugin) PreHook(req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) {
func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) {
traceID := time.Now().Format("20060102_150405000")

trace := plugin.logger.Trace(&logging.TraceConfig{
Expand All @@ -23,15 +31,18 @@ func (plugin *Plugin) PreHook(req *interfaces.BifrostRequest) (*interfaces.Bifro

trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req))

req.PluginParams["traceID"] = traceID
// Store traceID in context
*ctx = context.WithValue(*ctx, traceIDKey, traceID)

return req, nil
}

func (plugin *Plugin) PostHook(res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) {
fmt.Println(res.PluginParams)

traceID := res.PluginParams["traceID"].(string)
func (plugin *Plugin) PostHook(ctx *context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) {
// Get traceID from context
traceID, ok := (*ctx).Value(traceIDKey).(string)
if !ok {
return res, fmt.Errorf("traceID not found in context")
}

plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res))
return res, nil
Expand Down