diff --git a/bifrost.go b/bifrost.go index a331a69de..155fbfc2f 100644 --- a/bifrost.go +++ b/bifrost.go @@ -25,7 +25,7 @@ const ( type ChannelMessage struct { interfaces.BifrostRequest Response chan *interfaces.BifrostResponse - Err chan error + Err chan interfaces.BifrostError Type RequestType } @@ -107,7 +107,8 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf for _, providerKey := range providerKeys { config, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil { - return nil, fmt.Errorf("failed to get config for provider: %v", err) + bifrost.logger.Warn(fmt.Sprintf("failed to get config for provider, skipping init: %v", err)) + continue } if err := bifrost.prepareProvider(providerKey, config); err != nil { @@ -176,29 +177,47 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan for req := range queue { var result *interfaces.BifrostResponse var err error + var bifrostError *interfaces.BifrostError key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) if err != nil { - req.Err <- err + req.Err <- interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + Error: err, + }, + } + continue } if req.Type == TextCompletionRequest { if req.Input.TextCompletionInput == nil { - err = fmt.Errorf("text not provided for text completion request") + bifrostError = &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "text not provided for text completion request", + }, + } } else { - result, err = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params) + result, bifrostError = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params) } } else if req.Type == ChatCompletionRequest { if req.Input.ChatCompletionInput == nil { - err = fmt.Errorf("chats not provided for chat completion request") + bifrostError = &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "chats not provided for chat completion request", + }, + } } else { - result, err = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params) + result, bifrostError = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params) } } - if err != nil { - req.Err <- err + if bifrostError != nil { + req.Err <- *bifrostError } else { req.Response <- result } @@ -237,28 +256,48 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr return queue, nil } -func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) { +func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) { if req == nil { - return nil, fmt.Errorf("bifrost request cannot be nil") + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "bifrost request cannot be nil", + }, + } } queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { - return nil, err + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + }, + } } responseChan := make(chan *interfaces.BifrostResponse) - errorChan := make(chan error) + errorChan := make(chan interfaces.BifrostError) for _, plugin := range bifrost.plugins { req, err = plugin.PreHook(&ctx, req) if err != nil { - return nil, err + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + }, + } } } if req == nil { - return nil, fmt.Errorf("bifrost request after plugin hooks cannot be nil") + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "bifrost request after plugin hooks cannot be nil", + }, + } } queue <- ChannelMessage{ @@ -275,38 +314,63 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo result, err = bifrost.plugins[i].PostHook(&ctx, result) if err != nil { - return nil, err + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + }, + } } } return result, nil case err := <-errorChan: - return nil, err + return nil, &err } } -func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) { +func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) { if req == nil { - return nil, fmt.Errorf("bifrost request cannot be nil") + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "bifrost request cannot be nil", + }, + } } queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { - return nil, err + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + }, + } } responseChan := make(chan *interfaces.BifrostResponse) - errorChan := make(chan error) + errorChan := make(chan interfaces.BifrostError) for _, plugin := range bifrost.plugins { req, err = plugin.PreHook(&ctx, req) if err != nil { - return nil, err + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + }, + } } } if req == nil { - return nil, fmt.Errorf("bifrost request after pre plugin hooks cannot be nil") + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "bifrost request after plugin hooks cannot be nil", + }, + } } queue <- ChannelMessage{ @@ -323,13 +387,18 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo result, err = bifrost.plugins[i].PostHook(&ctx, result) if err != nil { - return nil, err + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: err.Error(), + }, + } } } return result, nil case err := <-errorChan: - return nil, err + return nil, &err } } diff --git a/interfaces/bifrost.go b/interfaces/bifrost.go index 4f4c67771..f298bf32c 100644 --- a/interfaces/bifrost.go +++ b/interfaces/bifrost.go @@ -228,3 +228,20 @@ type BifrostResponse struct { Usage LLMUsage `json:"usage"` ExtraFields BifrostResponseExtraFields `json:"extra_fields"` } + +type BifrostError struct { + EventID *string `json:"event_id"` + Type *string `json:"type"` + IsBifrostError bool `json:"is_bifrost_error"` + StatusCode *int `json:"status_code"` + Error ErrorField `json:"error"` +} + +type ErrorField struct { + Type *string `json:"type"` + Code *string `json:"code"` + Message string `json:"message"` + Error error `json:"error"` + Param interface{} `json:"param"` + EventID *string `json:"event_id"` +} diff --git a/interfaces/provider.go b/interfaces/provider.go index cb07a7596..bb5b6ddfc 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -23,11 +23,12 @@ type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` MetaConfig *MetaConfig `json:"meta_config,omitempty"` ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` + Logger Logger `json:"logger"` } // Provider defines the interface for AI model providers type Provider interface { GetProviderKey() SupportedModelProvider - TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, error) - ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, error) + TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, *BifrostError) } diff --git a/providers/anthropic.go b/providers/anthropic.go index 74b40e382..15a58d783 100644 --- a/providers/anthropic.go +++ b/providers/anthropic.go @@ -49,6 +49,14 @@ type AnthropicChatResponse struct { } `json:"usage"` } +type AnthropicError struct { + Type string `json:"type"` + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + // AnthropicProvider implements the Provider interface for Anthropic's Claude API type AnthropicProvider struct { logger interfaces.Logger @@ -116,20 +124,17 @@ func (provider *AnthropicProvider) PrepareToolChoices(params map[string]interfac return params } -// TextCompletion implements text completion using Anthropic's API -func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { - preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params)) - - // Merge additional parameters - requestBody := MergeConfig(map[string]interface{}{ - "model": model, - "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), - }, preparedParams) - +func (provider *AnthropicProvider) CompleteRequest(requestBody map[string]interface{}, url string, key string) ([]byte, *interfaces.BifrostError) { // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error marshaling request", + Error: err, + }, + } } // Create the request with the JSON body @@ -147,27 +152,84 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param // Send the request if err := provider.client.Do(req, resp); err != nil { - return nil, fmt.Errorf("error sending request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error sending request", + Error: err, + }, + } } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, fmt.Errorf("anthropic error: %s", resp.Body()) + var errorResp AnthropicError + if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing error response", + Error: err, + }, + } + } + + statusCode := resp.StatusCode() + + return nil, &interfaces.BifrostError{ + Type: &errorResp.Type, + IsBifrostError: false, + StatusCode: &statusCode, + Error: interfaces.ErrorField{ + Type: &errorResp.Error.Type, + Message: errorResp.Error.Message, + }, + } } // Read the response body body := resp.Body() + return body, nil +} + +// TextCompletion implements text completion using Anthropic's API +func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { + preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params)) + + // Merge additional parameters + requestBody := MergeConfig(map[string]interface{}{ + "model": model, + "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), + }, preparedParams) + + body, err := provider.CompleteRequest(requestBody, "https://api.anthropic.com/v1/complete", key) + if err != nil { + return nil, err + } + // Parse the response var response AnthropicTextResponse if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("error parsing response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } } // Parse raw response var rawResponse interface{} if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, fmt.Errorf("error parsing raw response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } } // Create the completion result @@ -198,7 +260,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param } // ChatCompletion implements chat completion using Anthropic's API -func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { +func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { // Format messages for Anthropic API var formattedMessages []map[string]interface{} for _, msg := range messages { @@ -264,44 +326,33 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] "messages": formattedMessages, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + body, err := provider.CompleteRequest(requestBody, "https://api.anthropic.com/v1/messages", key) if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) + return nil, err } - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - req.SetRequestURI("https://api.anthropic.com/v1/messages") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("x-api-key", key) - req.Header.Set("anthropic-version", "2023-06-01") - req.SetBody(jsonBody) - - // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - return nil, fmt.Errorf("anthropic error: %s", resp.Body()) - } - - // Decode structured response + // Decode response var response AnthropicChatResponse - if err := json.Unmarshal(resp.Body(), &response); err != nil { - return nil, fmt.Errorf("error decoding structured response: %v", err) + if err := json.Unmarshal(body, &response); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error decoding response", + Error: err, + }, + } } // Decode raw response var rawResponse interface{} - if err := json.Unmarshal(resp.Body(), &rawResponse); err != nil { - return nil, fmt.Errorf("error decoding raw response: %v", err) + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } } // Process the response into our BifrostResponse format diff --git a/providers/bedrock.go b/providers/bedrock.go index a245b6e70..890b4f41b 100644 --- a/providers/bedrock.go +++ b/providers/bedrock.go @@ -3,7 +3,6 @@ package providers import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -98,6 +97,10 @@ type BedrockAnthropicToolSpec struct { } `json:"inputSchema"` } +type BedrockError struct { + Message string `json:"message"` +} + type BedrockProvider struct { client *http.Client meta *interfaces.MetaConfig @@ -114,9 +117,14 @@ func (provider *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvi return interfaces.Bedrock } -func (provider *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) { +func (provider *BedrockProvider) CompleteRequest(requestBody map[string]interface{}, path string, accessKey string) ([]byte, *interfaces.BifrostError) { if provider.meta == nil { - return nil, errors.New("meta config for bedrock is not provided") + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "meta config for bedrock is not provided", + }, + } } region := "us-east-1" @@ -124,20 +132,82 @@ func (provider *BedrockProvider) PrepareReq(path string, jsonData []byte, access region = *provider.meta.Region } + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error marshaling request", + Error: err, + }, + } + } + // Create the request with the JSON body - req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error creating request", + Error: err, + }, + } } if err := SignAWSRequest(req, accessKey, *provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil { return nil, err } - return req, nil + // Execute the request + resp, err := provider.client.Do(req) + if err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error sending request", + Error: err, + }, + } + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error reading request", + Error: err, + }, + } + } + + if resp.StatusCode != http.StatusOK { + var errorResp BedrockError + if err := json.Unmarshal(body, &errorResp); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing error response", + Error: err, + }, + } + } + + return nil, &interfaces.BifrostError{ + StatusCode: &resp.StatusCode, + Error: interfaces.ErrorField{ + Message: errorResp.Message, + }, + } + } + + return body, nil } -func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.BifrostResponse, error) { +func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.BifrostResponse, *interfaces.BifrostError) { switch model { case "anthropic.claude-instant-v1:2": fallthrough @@ -146,7 +216,13 @@ func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model st case "anthropic.claude-v2:1": var response BedrockAnthropicTextResponse if err := json.Unmarshal(result, &response); err != nil { - return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } } return &interfaces.BifrostResponse{ @@ -178,7 +254,13 @@ func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model st case "mistral.mistral-small-2402-v1:0": var response BedrockMistralTextResponse if err := json.Unmarshal(result, &response); err != nil { - return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } } var choices []interfaces.BifrostResponseChoice @@ -202,10 +284,15 @@ func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model st }, nil } - return nil, fmt.Errorf("invalid model choice: %s", model) + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: fmt.Sprintf("invalid model choice: %s", model), + }, + } } -func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) { +func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, *interfaces.BifrostError) { switch model { case "anthropic.claude-instant-v1:2": fallthrough @@ -315,7 +402,12 @@ func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interf return body, nil } - return nil, fmt.Errorf("invalid model choice: %s", model) + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: fmt.Sprintf("invalid model choice: %s", model), + }, + } } func (provider *BedrockProvider) GetChatCompletionTools(params *interfaces.ModelParameters, model string) []BedrockAnthropicToolCall { @@ -377,51 +469,33 @@ func (provider *BedrockProvider) PrepareTextCompletionParams(params map[string]i return params } -func (provider *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { +func (provider *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params), model) requestBody := MergeConfig(map[string]interface{}{ "prompt": text, }, preparedParams) - // Marshal the request body - jsonData, err := json.Marshal(requestBody) + body, err := provider.CompleteRequest(requestBody, fmt.Sprintf("%s/invoke", model), key) if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) - } - - // Create the signed request with correct operation name - req, err := provider.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - // Execute the request - resp, err := provider.client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request: %v", err) - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %v", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bedrock API error: %s", string(body)) + return nil, err } result, err := provider.GetTextCompletionResult(body, model) if err != nil { - return nil, fmt.Errorf("failed to parse response body: %v", err) + return nil, err } // Parse raw response var rawResponse interface{} if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, fmt.Errorf("failed to parse raw response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } } result.ExtraFields.RawResponse = rawResponse @@ -429,10 +503,10 @@ func (provider *BedrockProvider) TextCompletion(model, key, text string, params return result, nil } -func (provider *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { +func (provider *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { messageBody, err := provider.PrepareChatCompletionMessages(messages, model) if err != nil { - return nil, fmt.Errorf("error preparing messages: %v", err) + return nil, err } preparedParams := PrepareParams(params) @@ -444,12 +518,6 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []in requestBody := MergeConfig(messageBody, preparedParams) - // Marshal the request body - jsonData, err := json.Marshal(requestBody) - if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) - } - // Format the path with proper model identifier path := fmt.Sprintf("%s/converse", model) @@ -463,37 +531,32 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []in } // Create the signed request - req, err := provider.PrepareReq(path, jsonData, key) + body, err := provider.CompleteRequest(requestBody, path, key) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - // Execute the request - resp, err := provider.client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request: %v", err) - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %v", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bedrock API error: %s", string(body)) + return nil, err } var response BedrockChatResponse if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } } // Parse raw response var rawResponse interface{} if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, fmt.Errorf("failed to parse raw response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } } var choices []interfaces.BifrostResponseChoice diff --git a/providers/cohere.go b/providers/cohere.go index 4f14e3d67..418079a11 100644 --- a/providers/cohere.go +++ b/providers/cohere.go @@ -56,6 +56,10 @@ type CohereChatResponse struct { ToolCalls []CohereToolCall `json:"tool_calls"` } +type CohereError struct { + Message string `json:"message"` +} + // OpenAIProvider implements the Provider interface for OpenAI type CohereProvider struct { client *fasthttp.Client @@ -76,11 +80,16 @@ func (provider *CohereProvider) GetProviderKey() interfaces.SupportedModelProvid return interfaces.Cohere } -func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { - return nil, fmt.Errorf("text completion is not supported by Cohere") +func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "text completion is not supported by cohere provider", + }, + } } -func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { +func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { // Get the last message and chat history lastMessage := messages[len(messages)-1] chatHistory := messages[:len(messages)-1] @@ -140,7 +149,13 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int // Marshal request body jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error marshaling request", + Error: err, + }, + } } // Create request @@ -157,12 +172,37 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int // Make request if err := provider.client.Do(req, resp); err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error sending request", + Error: err, + }, + } } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, fmt.Errorf("cohere error: %s", resp.Body()) + var errorResp CohereError + if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing error response", + Error: err, + }, + } + } + + statusCode := resp.StatusCode() + + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: interfaces.ErrorField{ + Message: errorResp.Message, + }, + } } // Read response body @@ -171,13 +211,25 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int // Decode response var response CohereChatResponse if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } } // Parse raw response var rawResponse interface{} if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, fmt.Errorf("failed to parse raw response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } } // Transform tool calls if present diff --git a/providers/openai.go b/providers/openai.go index c3b5cf2ed..b7f8b44ae 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -2,14 +2,14 @@ package providers import ( "encoding/json" - "fmt" + "time" "github.com/maximhq/bifrost/interfaces" "github.com/valyala/fasthttp" ) -type OpenAIResponse struct { +type OpenAIChatResponse struct { ID string `json:"id"` Object string `json:"object"` // text.completion or chat.completion Choices []interfaces.BifrostResponseChoice `json:"choices"` @@ -20,6 +20,18 @@ type OpenAIResponse struct { Usage interfaces.LLMUsage `json:"usage"` } +type OpenAIError struct { + EventID string `json:"event_id"` + Type string `json:"type"` + Error struct { + Type string `json:"type"` + Code string `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + EventID string `json:"event_id"` + } `json:"error"` +} + // OpenAIProvider implements the Provider interface for OpenAI type OpenAIProvider struct { logger interfaces.Logger @@ -43,11 +55,16 @@ func (provider *OpenAIProvider) GetProviderKey() interfaces.SupportedModelProvid } // TextCompletion performs text completion -func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { - return nil, fmt.Errorf("text completion is not supported by OpenAI") +func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "text completion is not supported by openai provider", + }, + } } -func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { +func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { // Format messages for OpenAI API var formattedMessages []map[string]interface{} for _, msg := range messages { @@ -96,7 +113,13 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error marshaling request", + Error: err, + }, + } } // Create request @@ -113,26 +136,68 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int // Make request if err := provider.client.Do(req, resp); err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error sending request", + Error: err, + }, + } } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, fmt.Errorf("OpenAI error: %s", resp.Body()) + var errorResp OpenAIError + if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing error response", + Error: err, + }, + } + } + + statusCode := resp.StatusCode() + + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + EventID: &errorResp.EventID, + StatusCode: &statusCode, + Error: interfaces.ErrorField{ + Type: &errorResp.Error.Type, + Code: &errorResp.Error.Code, + Message: errorResp.Error.Message, + Param: errorResp.Error.Param, + EventID: &errorResp.Error.EventID, + }, + } } body := resp.Body() // Decode structured response - var response OpenAIResponse + var response OpenAIChatResponse if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("error decoding structured response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } } // Decode raw response var rawResponse interface{} if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, fmt.Errorf("error decoding raw response: %v", err) + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } } result := &interfaces.BifrostResponse{ diff --git a/providers/utils.go b/providers/utils.go index 9acbef7ab..b8ca0ec13 100644 --- a/providers/utils.go +++ b/providers/utils.go @@ -5,7 +5,6 @@ import ( "context" "crypto/sha256" "encoding/hex" - "fmt" "io" "net/http" "reflect" @@ -81,7 +80,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} { return flatParams } -func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string) error { +func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string) *interfaces.BifrostError { // Set required headers before signing req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -91,7 +90,13 @@ func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken if req.Body != nil { bodyBytes, err := io.ReadAll(req.Body) if err != nil { - return fmt.Errorf("failed to read request body: %v", err) + return &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error reading request body", + Error: err, + }, + } } // Restore the body for subsequent reads req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -118,7 +123,13 @@ func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken })), ) if err != nil { - return fmt.Errorf("failed to load AWS config: %v", err) + return &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "failed to load aws config", + Error: err, + }, + } } // Create the AWS signer @@ -127,12 +138,24 @@ func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken // Get credentials creds, err := cfg.Credentials.Retrieve(context.TODO()) if err != nil { - return fmt.Errorf("failed to retrieve credentials: %v", err) + return &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "failed to retrieve aws credentials", + Error: err, + }, + } } // Sign the request with AWS Signature V4 if err := signer.SignHTTP(context.TODO(), creds, req, bodyHash, service, region, time.Now()); err != nil { - return fmt.Errorf("failed to sign request: %v", err) + return &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "failed to sign request", + Error: err, + }, + } } return nil diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go index 9def641eb..a65888f26 100644 --- a/tests/anthropic_test.go +++ b/tests/anthropic_test.go @@ -34,7 +34,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { Params: ¶ms, }, ctx) if err != nil { - fmt.Println("Error:", err) + fmt.Println("Error:", err.Error.Message) } else { fmt.Println("🤖 Text Completion Result:", *result.Choices[0].Message.Content) } @@ -66,7 +66,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { }, ctx) if err != nil { - fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err) + fmt.Printf("Error in Anthropic request %d: %v\n", index+1, err.Error.Message) } else { fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) } @@ -109,9 +109,9 @@ func setupAnthropicImageTests(bifrost *bifrost.Bifrost, ctx context.Context) { Params: ¶ms, }, ctx) if err != nil { - fmt.Printf("Error in Anthropic URL image request: %v\n", err) + fmt.Printf("Error in Anthropic URL image request: %v\n", err.Error.Message) } else { - fmt.Printf("🐒 URL Image Result: %s\n", result.Choices[0].Message.Content) + fmt.Printf("🐒 URL Image Result: %s\n", *result.Choices[0].Message.Content) } }() @@ -137,9 +137,9 @@ func setupAnthropicImageTests(bifrost *bifrost.Bifrost, ctx context.Context) { Params: ¶ms, }, ctx) if err != nil { - fmt.Printf("Error in Anthropic base64 image request: %v\n", err) + fmt.Printf("Error in Anthropic base64 image request: %v\n", err.Error.Message) } else { - fmt.Printf("🐒 Base64 Image Result: %s\n", result.Choices[0].Message.Content) + fmt.Printf("🐒 Base64 Image Result: %s\n", *result.Choices[0].Message.Content) } }() } @@ -196,7 +196,7 @@ func setupAnthropicToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { }, ctx) if err != nil { - fmt.Printf("Error in Anthropic tool call request %d: %v\n", index+1, err) + fmt.Printf("Error in Anthropic tool call request %d: %v\n", index+1, err.Error.Message) } else { toolCall := *result.Choices[1].Message.ToolCalls fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) diff --git a/tests/bedrock_test.go b/tests/bedrock_test.go index 1777d8612..898bdb128 100644 --- a/tests/bedrock_test.go +++ b/tests/bedrock_test.go @@ -34,7 +34,7 @@ func setupBedrockRequests(bifrost *bifrost.Bifrost) { Params: ¶ms, }, ctx) if err != nil { - fmt.Println("Error:", err) + fmt.Println("Error:", err.Error.Message) } else { fmt.Println("🤖 Text Completion Result:", *result.Choices[0].Message.Content) } @@ -66,7 +66,7 @@ func setupBedrockRequests(bifrost *bifrost.Bifrost) { }, ctx) if err != nil { - fmt.Printf("Error in Bedrock request %d: %v\n", index+1, err) + fmt.Printf("Error in Bedrock request %d: %v\n", index+1, err.Error.Message) } else { fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) } @@ -110,9 +110,9 @@ func setupBedrockImageTests(bifrost *bifrost.Bifrost, ctx context.Context) { Params: ¶ms, }, ctx) if err != nil { - fmt.Printf("Error in Bedrock base64 image request: %v\n", err) + fmt.Printf("Error in Bedrock base64 image request: %v\n", err.Error.Message) } else { - fmt.Printf("🐒 Base64 Image Result: %s\n", result.Choices[0].Message.Content) + fmt.Printf("🐒 Base64 Image Result: %s\n", *result.Choices[0].Message.Content) } }() } @@ -169,7 +169,7 @@ func setupBedrockToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { }, ctx) if err != nil { - fmt.Printf("Error in Bedrock tool call request %d: %v\n", index+1, err) + fmt.Printf("Error in Bedrock tool call request %d: %v\n", index+1, err.Error.Message) } else { if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 { toolCall := *result.Choices[0].Message.ToolCalls diff --git a/tests/cohere_test.go b/tests/cohere_test.go index d9f2bcdec..030edd65e 100644 --- a/tests/cohere_test.go +++ b/tests/cohere_test.go @@ -25,7 +25,7 @@ func setupCohereRequests(bifrost *bifrost.Bifrost) { Params: nil, }, ctx) if err != nil { - fmt.Println("Error:", err) + fmt.Println("Error:", err.Error.Message) } else { fmt.Println("🐒 Text Completion Result:", result.Choices[0].Message.Content) } @@ -56,7 +56,7 @@ func setupCohereRequests(bifrost *bifrost.Bifrost) { Params: nil, }, ctx) if err != nil { - fmt.Printf("Error in Cohere request %d: %v\n", index+1, err) + fmt.Printf("Error in Cohere request %d: %v\n", index+1, err.Error.Message) } else { fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) } @@ -115,7 +115,7 @@ func setupCohereToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { Params: ¶ms, }, ctx) if err != nil { - fmt.Printf("Error in Cohere tool call request %d: %v\n", index+1, err) + fmt.Printf("Error in Cohere tool call request %d: %v\n", index+1, err.Error.Message) } else { toolCall := *result.Choices[0].Message.ToolCalls fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) diff --git a/tests/openai_test.go b/tests/openai_test.go index 91127452f..655d2d98a 100644 --- a/tests/openai_test.go +++ b/tests/openai_test.go @@ -27,7 +27,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { Params: nil, }, ctx) if err != nil { - fmt.Println("Error:", err) + fmt.Println("Error:", err.Error.Message) } else { fmt.Println("🐒 Text Completion Result:", result.Choices[0].Message.Content) } @@ -58,7 +58,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { Params: nil, }, ctx) if err != nil { - fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err) + fmt.Printf("Error in OpenAI request %d: %v\n", index+1, err.Error.Message) } else { fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) } @@ -94,9 +94,9 @@ func setupOpenAIImageTests(bifrost *bifrost.Bifrost, ctx context.Context) { Params: nil, }, ctx) if err != nil { - fmt.Printf("Error in OpenAI URL image request: %v\n", err) + fmt.Printf("Error in OpenAI URL image request: %v\n", err.Error.Message) } else { - fmt.Printf("🐒 URL Image Result: %s\n", result.Choices[0].Message.Content) + fmt.Printf("🐒 URL Image Result: %s\n", *result.Choices[0].Message.Content) } }() @@ -120,9 +120,9 @@ func setupOpenAIImageTests(bifrost *bifrost.Bifrost, ctx context.Context) { Params: nil, }, ctx) if err != nil { - fmt.Printf("Error in OpenAI base64 image request: %v\n", err) + fmt.Printf("Error in OpenAI base64 image request: %v\n", err.Error.Message) } else { - fmt.Printf("🐒 Base64 Image Result: %s\n", result.Choices[0].Message.Content) + fmt.Printf("🐒 Base64 Image Result: %s\n", *result.Choices[0].Message.Content) } }() } @@ -175,7 +175,7 @@ func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { Params: ¶ms, }, ctx) if err != nil { - fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err) + fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err.Error.Message) } else { toolCall := *result.Choices[0].Message.ToolCalls fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments)