Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr
return queue, nil
}

func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest) (*interfaces.CompletionResult, error) {
func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) {
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
Expand All @@ -219,12 +219,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
responseChan := make(chan *interfaces.CompletionResult)
errorChan := make(chan error)

// Create a context with timeout same as the provider/request config
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
ctx, req, err = plugin.PreHook(ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -240,7 +236,7 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
select {
case result := <-responseChan:
for _, plugin := range bifrost.plugins {
result, err = plugin.PostHook(&ctx, result)
result, err = plugin.PostHook(ctx, result)

if err != nil {
return nil, err
Expand All @@ -250,12 +246,10 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
return result, nil
case err := <-errorChan:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest) (*interfaces.CompletionResult, error) {
func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) {
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
Expand All @@ -264,12 +258,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
responseChan := make(chan *interfaces.CompletionResult)
errorChan := make(chan error)

// Create a context with timeout same as the provider/request config
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

for _, plugin := range bifrost.plugins {
req, err = plugin.PreHook(&ctx, req)
ctx, req, err = plugin.PreHook(ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -285,7 +275,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
select {
case result := <-responseChan:
for _, plugin := range bifrost.plugins {
result, err = plugin.PostHook(&ctx, result)
result, err = plugin.PostHook(ctx, result)

if err != nil {
return nil, err
Expand All @@ -295,8 +285,6 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
return result, nil
case err := <-errorChan:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ go 1.21.1

require github.com/joho/godotenv v1.5.1

require github.com/maximhq/maxim-go v0.1.1 // indirect
require github.com/maximhq/maxim-go v0.1.1
4 changes: 2 additions & 2 deletions interfaces/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ type BifrostRequest struct {
}

type Plugin interface {
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error)
PostHook(ctx *context.Context, result *CompletionResult) (*CompletionResult, error)
PreHook(ctx context.Context, req *BifrostRequest) (context.Context, *BifrostRequest, error)
PostHook(ctx context.Context, result *CompletionResult) (*CompletionResult, error)
}
7 changes: 5 additions & 2 deletions tests/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tests
import (
"bifrost"
"bifrost/interfaces"
"context"
"fmt"
"testing"
"time"
Expand All @@ -17,6 +18,8 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
"Tell me about artificial intelligence.",
}

ctx := context.Background()

go func() {
params := interfaces.ModelParameters{
ExtraParams: map[string]interface{}{
Expand All @@ -31,7 +34,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
StringInput: &text,
},
Params: &params,
})
}, ctx)
if err != nil {
fmt.Println("Error:", err)
} else {
Expand Down Expand Up @@ -61,7 +64,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
MessageInput: &messages,
},
Params: &params,
})
}, ctx)

if err != nil {
fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err)
Expand Down
7 changes: 5 additions & 2 deletions tests/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tests
import (
"bifrost"
"bifrost/interfaces"
"context"
"fmt"
"testing"
"time"
Expand All @@ -12,6 +13,8 @@ import (
func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
text := "Hello world!"

ctx := context.Background()

// Text completion request
go func() {
result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
Expand All @@ -20,7 +23,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
StringInput: &text,
},
Params: nil,
})
}, ctx)
if err != nil {
fmt.Println("Error:", err)
} else {
Expand Down Expand Up @@ -52,7 +55,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
MessageInput: &messages,
},
Params: nil,
})
}, ctx)
if err != nil {
fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err)
} else {
Expand Down
10 changes: 5 additions & 5 deletions tests/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Plugin struct {
logger *logging.Logger
}

func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostRequest) (*interfaces.BifrostRequest, error) {
func (plugin *Plugin) PreHook(ctx context.Context, req *interfaces.BifrostRequest) (context.Context, *interfaces.BifrostRequest, error) {
traceID := time.Now().Format("20060102_150405000")

trace := plugin.logger.Trace(&logging.TraceConfig{
Expand All @@ -32,14 +32,14 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *interfaces.BifrostReque
trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req))

// Store traceID in context
*ctx = context.WithValue(*ctx, traceIDKey, traceID)
ctx = context.WithValue(ctx, traceIDKey, traceID)

return req, nil
return ctx, req, nil
}

func (plugin *Plugin) PostHook(ctx *context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) {
func (plugin *Plugin) PostHook(ctx context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) {
// Get traceID from context
traceID, ok := (*ctx).Value(traceIDKey).(string)
traceID, ok := ctx.Value(traceIDKey).(string)
if !ok {
return res, fmt.Errorf("traceID not found in context")
}
Expand Down