diff --git a/bifrost.go b/bifrost.go
index a3e41e767..f3c570c34 100644
--- a/bifrost.go
+++ b/bifrost.go
@@ -43,6 +43,8 @@ func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider
 		return providers.NewOpenAIProvider(config), nil
 	case interfaces.Anthropic:
 		return providers.NewAnthropicProvider(config), nil
+	case interfaces.Bedrock:
+		return providers.NewBedrockProvider(config), nil
 	default:
 		return nil, fmt.Errorf("unsupported provider: %s", providerKey)
 	}
@@ -99,6 +101,7 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, configs map[i
 	}
 
 	bifrost.requestQueues = make(map[interfaces.SupportedModelProvider]chan ChannelMessage)
+	bifrost.configs = configs
 
 	// Create buffered channels for each provider and start workers
 	for _, providerKey := range providerKeys {
@@ -184,9 +187,9 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
 			}
 
 			if req.Type == TextCompletionRequest {
-				result, err = provider.TextCompletion(req.Model, key, *req.Input.StringInput, req.Params)
+				result, err = provider.TextCompletion(req.Model, key, *req.Input.TextInput, req.Params)
 			} else if req.Type == ChatCompletionRequest {
-				result, err = provider.ChatCompletion(req.Model, key, *req.Input.MessageInput, req.Params)
+				result, err = provider.ChatCompletion(req.Model, key, *req.Input.ChatInput, req.Params)
 			}
 
 			if err != nil {
diff --git a/go.mod b/go.mod
index 5f4f228b0..fdf984f83 100644
--- a/go.mod
+++ b/go.mod
@@ -1,7 +1,27 @@
 module bifrost
 
-go 1.21.1
+go 1.23.0
+
+toolchain go1.24.1
 
 require github.com/joho/godotenv v1.5.1
 
-require github.com/maximhq/maxim-go v0.1.1
+require (
+	github.com/aws/aws-sdk-go-v2 v1.36.3
+	github.com/aws/aws-sdk-go-v2/config v1.29.11
+	github.com/maximhq/maxim-go v0.1.1
+)
+
+require (
+	github.com/aws/aws-sdk-go-v2/credentials v1.17.64 // indirect
+	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 // indirect
+	github.com/aws/smithy-go v1.22.2 // indirect
+)
diff --git a/go.sum b/go.sum
index ebeda6c0c..1fa5e6c25 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,29 @@
+github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
+github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
+github.com/aws/aws-sdk-go-v2/config v1.29.11 h1:/hkJIxaQzFQy0ebFjG5NHmAcLCrvNSuXeHnxLfeCz1Y=
+github.com/aws/aws-sdk-go-v2/config v1.29.11/go.mod h1:OFPRZVQxC4mKqy2Go6Cse/m9NOStAo6YaMvAcTMUROg=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.64 h1:NH4RAQJEXBDQDUudTqMNHdyyEVa5CvMn0tQicqv48jo=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.64/go.mod h1:tUoJfj79lzEcalHDbyNkpnZZTRg/2ayYOK/iYnRfPbo=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY=
+github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 h1:pdgODsAhGo4dvzC3JAG5Ce0PX8kWXrTZGx+jxADD+5E=
+github.com/aws/aws-sdk-go-v2/service/sso v1.25.2/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2 h1:wK8O+j2dOolmpNVY1EWIbLgxrGCHJKVPm08Hv/u80M8=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.2/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs=
+github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 h1:PZV5W8yk4OtH1JAuhV2PXwwO9v5G5Aoj+eMCn4T+1Kc=
+github.com/aws/aws-sdk-go-v2/service/sts v1.33.17/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4=
+github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
+github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
 github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
 github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M=
diff --git a/interfaces/plugin.go b/interfaces/plugin.go
index bd322a87d..0de9d9144 100644
--- a/interfaces/plugin.go
+++ b/interfaces/plugin.go
@@ -3,8 +3,8 @@ package interfaces
 import "context"
 
 type RequestInput struct {
-	StringInput  *string
-	MessageInput *[]Message
+	TextInput *string
+	ChatInput *[]Message
 }
 
 type BifrostRequest struct {
diff --git a/interfaces/provider.go b/interfaces/provider.go
index 2eb6313ee..a3cf5548f 100644
--- a/interfaces/provider.go
+++ b/interfaces/provider.go
@@ -1,5 +1,7 @@
 package interfaces
 
+import "encoding/json"
+
 // LLMUsage represents token usage information
 type LLMUsage struct {
 	PromptTokens int `json:"prompt_tokens"`
@@ -55,9 +57,11 @@ type FunctionCall struct {
 
 // ToolCall represents a tool call in a message
 type ToolCall struct {
-	Type     string       `json:"type"`
-	ID       string       `json:"id"`
-	Function FunctionCall `json:"function"`
+	Type     *string         `json:"type"`
+	ID       string          `json:"id"`
+	Name     *string         `json:"name"`
+	Input    json.RawMessage `json:"input"`
+	Function *FunctionCall   `json:"function"`
 }
 
 // ModelChatMessageRole represents the role of a chat message
@@ -73,18 +77,19 @@ const (
 
 // CompletionResponseChoice represents a choice in the completion response
 type CompletionResponseChoice struct {
-	Role         ModelChatMessageRole `json:"role"`
-	Content      string               `json:"content"`
-	FunctionCall *FunctionCall        `json:"function_call"`
-	ToolCalls    *[]ToolCall          `json:"tool_calls"`
+	Role      ModelChatMessageRole `json:"role"`
+	Content   string               `json:"content"`
+	Image     json.RawMessage      `json:"image"`
+	ToolCalls *[]ToolCall          `json:"tool_calls"`
 }
 
 // CompletionResultChoice represents a choice in the completion result
 type CompletionResultChoice struct {
-	Index        int                      `json:"index"`
-	Message      CompletionResponseChoice `json:"message"`
-	FinishReason *string                  `json:"finish_reason"`
-	LogProbs     *interface{}             `json:"logprobs"`
+	Index      int                      `json:"index"`
+	Message    CompletionResponseChoice `json:"message"`
+	StopReason *string                  `json:"stop_reason"`
+	Stop       *string                  `json:"stop"`
+	LogProbs   *interface{}             `json:"logprobs"`
 }
 
 // ToolResult represents the result of a tool call
@@ -147,27 +152,19 @@ const (
 	Lmstudio SupportedModelProvider = "lmstudio"
 )
 
-type Role string
-
-const (
-	UserRole      Role = "user"
-	AssistantRole Role = "assistant"
-	SystemRole    Role = "system"
-)
-
 type Message struct {
 	//* strict check for roles
-	Role Role `json:"role"`
+	Role ModelChatMessageRole `json:"role"`
 	//* need to make sure either content or imagecontent is provided
 	Content      *string       `json:"content"`
 	ImageContent *ImageContent `json:"imageContent"`
+	ToolCalls    *[]ToolCall   `json:"toolCall"`
 }
 
 type ImageContent struct {
-	Type     string `json:"type"`
-	ImageURL struct {
-		URL string `json:"url"`
-	} `json:"image_url"`
+	Type      string `json:"type"`
+	URL       string `json:"url"`
+	MediaType string `json:"media_type"`
 }
 
 // type Content struct {
diff --git a/providers/anthropic.go b/providers/anthropic.go
index d181a4d19..0fe2394e1 100644
--- a/providers/anthropic.go
+++ b/providers/anthropic.go
@@ -8,6 +8,8 @@ import (
 	"io"
 	"net/http"
 	"time"
+
+	"github.com/maximhq/maxim-go"
 )
 
 type AnthropicTextResponse struct {
@@ -62,6 +64,8 @@ func (provider *AnthropicProvider) GetProviderKey() interfaces.SupportedModelPro
 
 // TextCompletion implements text completion using Anthropic's API
 func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+	startTime := time.Now()
+
 	preparedParams := PrepareParams(params)
 
 	// Merge additional parameters
@@ -122,6 +126,9 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
 		return nil, fmt.Errorf("error parsing response: %v", err)
 	}
 
+	// Calculate latency
+	latency := time.Since(startTime).Seconds()
+
 	// Create the completion result
 	completionResult := &interfaces.CompletionResult{
 		ID: response.ID,
@@ -138,6 +145,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
+			Latency:          &latency,
 		},
 		Model:    response.Model,
 		Provider: interfaces.Anthropic,
@@ -214,7 +222,7 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
 	// Process the response into our CompletionResult format
 	var content string
 	var toolCalls []interfaces.ToolCall
-	var finishReason string
+	var stopReason string
 
 	// Process content and tool calls
 	for _, c := range anthropicResponse.Content {
@@ -228,14 +236,14 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
 		case "tool_use":
 			if c.ToolUse != nil {
 				toolCalls = append(toolCalls, interfaces.ToolCall{
-					Type: "function",
+					Type: maxim.StrPtr("function"),
 					ID:   c.ToolUse.ID,
-					Function: interfaces.FunctionCall{
+					Function: &interfaces.FunctionCall{
 						Name:      c.ToolUse.Name,
 						Arguments: string(must(json.Marshal(c.ToolUse.Input))),
 					},
 				})
-				finishReason = "tool_calls"
+				stopReason = "tool_calls"
 			}
 		}
 	}
@@ -251,7 +259,7 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
 				Content:   content,
 				ToolCalls: &toolCalls,
 			},
-			FinishReason: &finishReason,
+			StopReason: &stopReason,
 		},
 	},
 	Usage: interfaces.LLMUsage{
diff --git a/providers/bedrock.go b/providers/bedrock.go
new file mode 100644
index 000000000..1f5309df8
--- /dev/null
+++ b/providers/bedrock.go
@@ -0,0 +1,427 @@
+package providers
+
+import (
+	"bifrost/interfaces"
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+)
+
+type BedrockAnthropicTextResponse struct {
+	Completion string `json:"completion"`
+	StopReason string `json:"stop_reason"`
+	Stop       string `json:"stop"`
+}
+
+type BedrockMistralTextResponse struct {
+	Outputs []struct {
+		Text       string `json:"text"`
+		StopReason string `json:"stop_reason"`
+	} `json:"outputs"`
+}
+
+type BedrockChatResponse struct {
+	Metrics struct {
+		Latency int `json:"latencyMs"`
+	} `json:"metrics"`
+	Output struct {
+		Message struct {
+			Content []struct {
+				Text string `json:"text"`
+			} `json:"content"`
+			Role string `json:"role"`
+		} `json:"message"`
+	} `json:"output"`
+	StopReason string `json:"stopReason"`
+	Usage      struct {
+		InputTokens  int `json:"inputTokens"`
+		OutputTokens int `json:"outputTokens"`
+		TotalTokens  int `json:"totalTokens"`
+	} `json:"usage"`
+}
+
+type BedrockAnthropicSystemMessage struct {
+	Text string `json:"text"`
+}
+
+type BedrockAnthropicTextMessage struct {
+	Type string `json:"type"`
+	Text string `json:"text"`
+}
+
+type BedrockMistralContent struct {
+	Text string `json:"text"`
+}
+
+type BedrockMistralChatMessage struct {
+	Role       interfaces.ModelChatMessageRole `json:"role"`
+	Content    []BedrockMistralContent         `json:"content"`
+	ToolCalls  *[]BedrockMistralToolCall       `json:"tool_calls,omitempty"`
+	ToolCallID *string                         `json:"tool_call_id,omitempty"`
+}
+
+type BedrockAnthropicImageMessage struct {
+	Type   string                      `json:"type"`
+	Source BedrockAnthropicImageSource `json:"source"`
+}
+
+type BedrockAnthropicImageSource struct {
+	Type      string `json:"type"`
+	MediaType string `json:"media_type"`
+	Data      string `json:"data"`
+}
+
+type BedrockMistralToolCall struct {
+	ID       string                  `json:"id"`
+	Function interfaces.FunctionCall `json:"function"`
+}
+
+type BedrockProvider struct {
+	client *http.Client
+	meta   *interfaces.BedrockMetaConfig
+}
+
+func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider {
+	return &BedrockProvider{
+		client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)},
+		meta:   config.MetaConfig.BedrockMetaConfig,
+	}
+}
+
+func (p *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider {
+	return interfaces.Bedrock
+}
+
+func (p *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) {
+	if p.meta == nil {
+		return nil, errors.New("meta config for bedrock is not provided")
+	}
+
+	region := "us-east-1"
+	if p.meta.Region != nil {
+		region = *p.meta.Region
+	}
+
+	// Create the request with the JSON body
+	req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("error creating request: %v", err)
+	}
+
+	if err := SignAWSRequest(req, accessKey, p.meta.SecretAccessKey, p.meta.SessionToken, region, "bedrock"); err != nil {
+		return nil, err
+	}
+
+	return req, nil
+}
+
+func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.CompletionResult, error) {
+	switch model {
+	case "anthropic.claude-instant-v1:2":
+		fallthrough
+	case "anthropic.claude-v2":
+		fallthrough
+	case "anthropic.claude-v2:1":
+		var response BedrockAnthropicTextResponse
+		if err := json.Unmarshal(result, &response); err != nil {
+			return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+		}
+
+		return &interfaces.CompletionResult{
+			Choices: []interfaces.CompletionResultChoice{
+				{
+					Index: 0,
+					Message: interfaces.CompletionResponseChoice{
+						Role:    interfaces.RoleAssistant,
+						Content: response.Completion,
+					},
+					StopReason: &response.StopReason,
+					Stop:       &response.Stop,
+				},
+			},
+		}, nil
+
+	case "mistral.mixtral-8x7b-instruct-v0:1":
+		fallthrough
+	case "mistral.mistral-7b-instruct-v0:2":
+		fallthrough
+	case "mistral.mistral-large-2402-v1:0":
+		fallthrough
+	case "mistral.mistral-large-2407-v1:0":
+		fallthrough
+	case "mistral.mistral-small-2402-v1:0":
+		var response BedrockMistralTextResponse
+		if err := json.Unmarshal(result, &response); err != nil {
+			return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+		}
+
+		var choices []interfaces.CompletionResultChoice
+		for i, output := range response.Outputs {
+			choices = append(choices, interfaces.CompletionResultChoice{
+				Index: i,
+				Message: interfaces.CompletionResponseChoice{
+					Role:    interfaces.RoleAssistant,
+					Content: output.Text,
+				},
+				StopReason: &output.StopReason,
+			})
+		}
+
+		return &interfaces.CompletionResult{
+			Choices: choices,
+		}, nil
+	}
+
+	return nil, fmt.Errorf("invalid model choice: %s", model)
+}
+
+func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) {
+	switch model {
+	case "anthropic.claude-instant-v1:2":
+		fallthrough
+	case "anthropic.claude-v2":
+		fallthrough
+	case "anthropic.claude-v2:1":
+		fallthrough
+	case "anthropic.claude-3-sonnet-20240229-v1:0":
+		fallthrough
+	case "anthropic.claude-3-5-sonnet-20240620-v1:0":
+		fallthrough
+	case "anthropic.claude-3-5-sonnet-20241022-v2:0":
+		fallthrough
+	case "anthropic.claude-3-5-haiku-20241022-v1:0":
+		fallthrough
+	case "anthropic.claude-3-opus-20240229-v1:0":
+		fallthrough
+	case "anthropic.claude-3-7-sonnet-20250219-v1:0":
+		// Add system messages if present
+		var systemMessages []BedrockAnthropicSystemMessage
+		for _, msg := range messages {
+			if msg.Role == interfaces.RoleSystem {
+				//TODO handling image inputs here
+				systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{
+					Text: *msg.Content,
+				})
+			}
+		}
+
+		// Format messages for Bedrock API
+		var bedrockMessages []map[string]interface{}
+		for _, msg := range messages {
+			if msg.Role != interfaces.RoleSystem {
+				var content any
+				if msg.Content != nil {
+					content = BedrockAnthropicTextMessage{
+						Type: "text",
+						Text: *msg.Content,
+					}
+				} else if msg.ImageContent != nil {
+					content = BedrockAnthropicImageMessage{
+						Type: "image",
+						Source: BedrockAnthropicImageSource{
+							Type:      msg.ImageContent.Type,
+							MediaType: msg.ImageContent.MediaType,
+							Data:      msg.ImageContent.URL,
+						},
+					}
+				}
+
+				bedrockMessages = append(bedrockMessages, map[string]interface{}{
+					"role":    msg.Role,
+					"content": []interface{}{content},
+				})
+			}
+		}
+
+		body := map[string]interface{}{
+			"messages": bedrockMessages,
+		}
+
+		if len(systemMessages) > 0 {
+			var messages []string
+			for _, message := range systemMessages {
+				messages = append(messages, message.Text)
+			}
+
+			body["system"] = strings.Join(messages, " ")
+		}
+
+		return body, nil
+
+	case "mistral.mistral-large-2402-v1:0":
+		fallthrough
+	case "mistral.mistral-large-2407-v1:0":
+		var bedrockMessages []BedrockMistralChatMessage
+		for _, msg := range messages {
+			var filteredToolCalls []BedrockMistralToolCall
+			if msg.ToolCalls != nil {
+				for _, toolCall := range *msg.ToolCalls {
+					filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{
+						ID:       toolCall.ID,
+						Function: *toolCall.Function,
+					})
+				}
+			}
+
+			message := BedrockMistralChatMessage{
+				Role: msg.Role,
+				Content: []BedrockMistralContent{
+					{Text: *msg.Content},
+				},
+			}
+
+			if len(filteredToolCalls) > 0 {
+				message.ToolCalls = &filteredToolCalls
+			}
+
+			bedrockMessages = append(bedrockMessages, message)
+		}
+
+		body := map[string]interface{}{
+			"messages": bedrockMessages,
+		}
+
+		return body, nil
+	}
+
+	return nil, fmt.Errorf("invalid model choice: %s", model)
+}
+
+func (p *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+	startTime := time.Now()
+
+	preparedParams := PrepareParams(params)
+
+	requestBody := MergeConfig(map[string]interface{}{
+		"prompt": text,
+	}, preparedParams)
+
+	// Marshal the request body
+	jsonData, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("error marshaling request: %v", err)
+	}
+
+	// Create the signed request with correct operation name
+	req, err := p.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key)
+	if err != nil {
+		return nil, fmt.Errorf("error creating request: %v", err)
+	}
+
+	// Execute the request
+	resp, err := p.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to execute request: %v", err)
+	}
+	defer resp.Body.Close()
+
+	// Read response body
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %v", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("bedrock API error: %s", string(body))
+	}
+
+	result, err := p.GetTextCompletionResult(body, model)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse response body: %v", err)
+	}
+	// Calculate latency
+	latency := time.Since(startTime).Seconds()
+	result.Usage.Latency = &latency
+
+	return result, nil
+}
+
+func (p *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+	messageBody, err := p.PrepareChatCompletionMessages(messages, model)
+	if err != nil {
+		return nil, fmt.Errorf("error preparing messages: %v", err)
+	}
+
+	preparedParams := PrepareParams(params)
+	requestBody := MergeConfig(messageBody, preparedParams)
+
+	// Marshal the request body
+	jsonData, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("error marshaling request: %v", err)
+	}
+
+	// Format the path with proper model identifier
+	path := fmt.Sprintf("%s/converse", model)
+
+	if p.meta != nil && p.meta.InferenceProfiles != nil {
+		if inferenceProfileId, ok := p.meta.InferenceProfiles[model]; ok {
+			if p.meta.ARN != nil {
+				encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *p.meta.ARN, inferenceProfileId))
+				path = fmt.Sprintf("%s/converse", encodedModelIdentifier)
+			}
+		}
+	}
+
+	// Create the signed request
+	req, err := p.PrepareReq(path, jsonData, key)
+	if err != nil {
+		return nil, fmt.Errorf("error creating request: %v", err)
+	}
+
+	// Execute the request
+	resp, err := p.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to execute request: %v", err)
+	}
+	defer resp.Body.Close()
+
+	// Read response body
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %v", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("bedrock API error: %s", string(body))
+	}
+
+	var response BedrockChatResponse
+	if err := json.Unmarshal(body, &response); err != nil {
+		return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
+	}
+
+	var choices []interfaces.CompletionResultChoice
+	for i, choice := range response.Output.Message.Content {
+		choices = append(choices, interfaces.CompletionResultChoice{
+			Index: i,
+			Message: interfaces.CompletionResponseChoice{
+				Role:    interfaces.RoleAssistant,
+				Content: choice.Text,
+			},
+			StopReason: &response.StopReason,
+		})
+	}
+
+	latency := float64(response.Metrics.Latency)
+
+	result := &interfaces.CompletionResult{
+		Choices: choices,
+		Usage: interfaces.LLMUsage{
+			PromptTokens:     response.Usage.InputTokens,
+			CompletionTokens: response.Usage.OutputTokens,
+			TotalTokens:      response.Usage.TotalTokens,
+			Latency:          &latency,
+		},
+		Model:    model,
+		Provider: interfaces.Bedrock,
+	}
+
+	return result, nil
+}
diff --git a/providers/openai.go b/providers/openai.go
index 9bcdef267..b97a86dcd 100644
--- a/providers/openai.go
+++ b/providers/openai.go
@@ -62,7 +62,6 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
 	// Format messages for OpenAI API
 	var openAIMessages []map[string]interface{}
 	for _, msg := range messages {
-
 		var content any
 		if msg.Content != nil {
 			content = msg.Content
diff --git a/providers/utils.go b/providers/utils.go
index bd925de75..56c590808 100644
--- a/providers/utils.go
+++ b/providers/utils.go
@@ -2,7 +2,19 @@ package providers
 
 import (
 	"bifrost/interfaces"
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"fmt"
+	"io"
+	"net/http"
 	"reflect"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+	"github.com/aws/aws-sdk-go-v2/config"
 )
 
 // MergeConfig merges default config with custom parameters
@@ -63,3 +75,60 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} {
 
 	return flatParams
 }
+
+func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string) error {
+	// Set required headers before signing
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+
+	// Calculate SHA256 hash of the request body
+	var bodyHash string
+	if req.Body != nil {
+		bodyBytes, err := io.ReadAll(req.Body)
+		if err != nil {
+			return fmt.Errorf("failed to read request body: %v", err)
+		}
+		// Restore the body for subsequent reads
+		req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+		hash := sha256.Sum256(bodyBytes)
+		bodyHash = hex.EncodeToString(hash[:])
+	} else {
+		// For empty body, use the hash of an empty string
+		hash := sha256.Sum256([]byte{})
+		bodyHash = hex.EncodeToString(hash[:])
+	}
+
+	cfg, err := config.LoadDefaultConfig(context.TODO(),
+		config.WithRegion(region),
+		config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) {
+			creds := aws.Credentials{
+				AccessKeyID:     accessKey,
+				SecretAccessKey: secretKey,
+			}
+			if sessionToken != nil {
+				creds.SessionToken = *sessionToken
+			}
+			return creds, nil
+		})),
+	)
+	if err != nil {
+		return fmt.Errorf("failed to load AWS config: %v", err)
+	}
+
+	// Create the AWS signer
+	signer := v4.NewSigner()
+
+	// Get credentials
+	creds, err := cfg.Credentials.Retrieve(context.TODO())
+	if err != nil {
+		return fmt.Errorf("failed to retrieve credentials: %v", err)
+	}
+
+	// Sign the request with AWS Signature V4
+	if err := signer.SignHTTP(context.TODO(), creds, req, bodyHash, service, region, time.Now()); err != nil {
+		return fmt.Errorf("failed to sign request: %v", err)
+	}
+
+	return nil
+}
diff --git a/tests/account.go b/tests/account.go
index 3af3543f1..be0d14c58 100644
--- a/tests/account.go
+++ b/tests/account.go
@@ -11,7 +11,7 @@ type BaseAccount struct{}
 
 // GetInitiallyConfiguredProviderKeys returns all provider keys
 func (baseAccount *BaseAccount) GetInitiallyConfiguredProviderKeys() ([]interfaces.SupportedModelProvider, error) {
-	return []interfaces.SupportedModelProvider{interfaces.OpenAI, interfaces.Anthropic}, nil
+	return []interfaces.SupportedModelProvider{interfaces.OpenAI, interfaces.Anthropic, interfaces.Bedrock}, nil
 }
 
 // GetKeysForProvider returns all keys associated with a provider
@@ -33,6 +33,14 @@ func (baseAccount *BaseAccount) GetKeysForProvider(provider interfaces.Provider)
 				Weight: 1.0,
 			},
 		}, nil
+	case interfaces.Bedrock:
+		return []interfaces.Key{
+			{
+				Value:  os.Getenv("BEDROCK_API_KEY"),
+				Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"},
+				Weight: 1.0,
+			},
+		}, nil
 	default:
 		return nil, fmt.Errorf("unsupported provider: %s", provider.GetProviderKey())
 	}
@@ -51,6 +59,11 @@ func (baseAccount *BaseAccount) GetConcurrencyAndBufferSizeForProvider(provider
 			Concurrency: 3,
 			BufferSize:  10,
 		}, nil
+	case interfaces.Bedrock:
+		return &interfaces.ConcurrencyAndBufferSize{
+			Concurrency: 3,
+			BufferSize:  10,
+		}, nil
 	default:
 		return nil, fmt.Errorf("unsupported provider: %s", provider.GetProviderKey())
 	}
diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go
index 9cad8fdb3..f4790f8b7 100644
--- a/tests/anthropic_test.go
+++ b/tests/anthropic_test.go
@@ -31,7 +31,7 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
 		result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
 			Model: "claude-2.1",
 			Input: interfaces.RequestInput{
-				StringInput: &text,
+				TextInput: &text,
 			},
 			Params: &params,
 		}, ctx)
@@ -54,14 +54,14 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
 			time.Sleep(delay)
 			messages := []interfaces.Message{
 				{
-					Role:    interfaces.UserRole,
+					Role:    interfaces.RoleUser,
 					Content: &msg,
 				},
 			}
 			result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
 				Model: "claude-3-7-sonnet-20250219",
 				Input: interfaces.RequestInput{
-					MessageInput: &messages,
+					ChatInput: &messages,
 				},
 				Params: &params,
 			}, ctx)
diff --git a/tests/bedrock_test.go b/tests/bedrock_test.go
new file mode 100644
index 000000000..4fb3ba53c
--- /dev/null
+++ b/tests/bedrock_test.go
@@ -0,0 +1,88 @@
+package tests
+
+import (
+	"bifrost"
+	"bifrost/interfaces"
+	"context"
+	"fmt"
+	"testing"
+	"time"
+)
+
+// setupBedrockRequests sends multiple test requests to Bedrock
+func setupBedrockRequests(bifrost *bifrost.Bifrost) {
+	bedrockMessages := []string{
+		"What's your favorite programming language?",
+		"Can you help me write a Go function?",
+		"What's the best way to learn programming?",
+		"Tell me about artificial intelligence.",
+	}
+
+	ctx := context.Background()
+
+	go func() {
+		params := interfaces.ModelParameters{
+			ExtraParams: map[string]interface{}{
+				"max_tokens_to_sample": 4096,
+			},
+		}
+		text := "\n\nHuman:\n\nAssistant:"
+
+		result, err := bifrost.TextCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
+			Model: "anthropic.claude-v2:1",
+			Input: interfaces.RequestInput{
+				TextInput: &text,
+			},
+			Params: &params,
+		}, ctx)
+		if err != nil {
+			fmt.Println("Error:", err)
+		} else {
+			fmt.Println("🤖 Text Completion Result:", result.Choices[0].Message.Content)
+		}
+	}()
+
+	params := interfaces.ModelParameters{
+		ExtraParams: map[string]interface{}{
+			"max_tokens": 4096,
+		},
+	}
+
+	for i, message := range bedrockMessages {
+		delay := time.Duration(500+100*i) * time.Millisecond
+		go func(msg string, delay time.Duration, index int) {
+			time.Sleep(delay)
+			messages := []interfaces.Message{
+				{
+					Role:    interfaces.RoleUser,
+					Content: &msg,
+				},
+			}
+			result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
+				Model: "anthropic.claude-3-sonnet-20240229-v1:0",
+				Input: interfaces.RequestInput{
+					ChatInput: &messages,
+				},
+				Params: &params,
+			}, ctx)
+
+			if err != nil {
+				fmt.Printf("Error in Bedrock request %d: %v\n", index+1, err)
+			} else {
+				fmt.Printf("🤖 Chat Completion Result %d: %s\n", index+1, result.Choices[0].Message.Content)
+			}
+		}(message, delay, i)
+	}
+}
+
+func TestBedrock(t *testing.T) {
+	bifrost, err := getBifrost()
+	if err != nil {
+		t.Fatalf("Error initializing bifrost: %v", err)
+		return
+	}
+
+	setupBedrockRequests(bifrost)
+
+	bifrost.Cleanup()
+}
diff --git a/tests/openai_test.go b/tests/openai_test.go
index 7ef30a29d..246f1a411 100644
--- a/tests/openai_test.go
+++ b/tests/openai_test.go
@@ -20,7 +20,7 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
 		result, err := bifrost.TextCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
 			Model: "gpt-4o-mini",
 			Input: interfaces.RequestInput{
-				StringInput: &text,
+				TextInput: &text,
 			},
 			Params: nil,
 		}, ctx)
@@ -45,14 +45,14 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
 			time.Sleep(delay)
 			messages := []interfaces.Message{
 				{
-					Role:    interfaces.UserRole,
+					Role:    interfaces.RoleUser,
 					Content: &msg,
 				},
 			}
 			result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
 				Model: "gpt-4o-mini",
 				Input: interfaces.RequestInput{
-					MessageInput: &messages,
+					ChatInput: &messages,
 				},
 				Params: nil,
 			}, ctx)
diff --git a/tests/setup.go b/tests/setup.go
index a31f6300a..7362cc54e 100644
--- a/tests/setup.go
+++ b/tests/setup.go
@@ -55,6 +55,17 @@ func getBifrost() (*bifrost.Bifrost, error) {
 				DefaultRequestTimeoutInSeconds: 30,
 			},
 		},
+		interfaces.Bedrock: {
+			NetworkConfig: interfaces.NetworkConfig{
+				DefaultRequestTimeoutInSeconds: 30,
+			},
+			MetaConfig: &interfaces.MetaConfig{
+				BedrockMetaConfig: &interfaces.BedrockMetaConfig{
+					SecretAccessKey: "AMpq95pNadM2fD1GlcNvjbMiGhizwYaGKJxv+nti",
+					Region:          maxim.StrPtr("us-east-1"),
+				},
+			},
+		},
 	}
 
 	bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}, configs)
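
Usage sketch (illustrative, not part of the patch): a minimal example of driving the new Bedrock provider directly, outside the Bifrost request queue, using only types introduced in this diff. The BEDROCK_SECRET_ACCESS_KEY environment variable name is an assumption; BEDROCK_API_KEY, the region, and the model ID mirror the test files above.

package main

import (
	"fmt"
	"os"

	"bifrost/interfaces"
	"bifrost/providers"

	"github.com/maximhq/maxim-go"
)

func main() {
	// Provider config mirroring tests/setup.go. BEDROCK_SECRET_ACCESS_KEY is an
	// assumed env var name; the patch itself hardcodes the secret in the tests.
	config := &interfaces.ProviderConfig{
		NetworkConfig: interfaces.NetworkConfig{DefaultRequestTimeoutInSeconds: 30},
		MetaConfig: &interfaces.MetaConfig{
			BedrockMetaConfig: &interfaces.BedrockMetaConfig{
				SecretAccessKey: os.Getenv("BEDROCK_SECRET_ACCESS_KEY"),
				Region:          maxim.StrPtr("us-east-1"),
			},
		},
	}

	provider := providers.NewBedrockProvider(config)

	prompt := "Summarize AWS SigV4 request signing in one sentence."
	messages := []interfaces.Message{{Role: interfaces.RoleUser, Content: &prompt}}
	params := &interfaces.ModelParameters{
		ExtraParams: map[string]interface{}{"max_tokens": 1024},
	}

	// The per-request key acts as the AWS access key ID, matching how
	// tests/account.go supplies BEDROCK_API_KEY; the secret and optional session
	// token come from the meta config and are signed into the request by
	// SignAWSRequest before the Converse call is made.
	result, err := provider.ChatCompletion(
		"anthropic.claude-3-sonnet-20240229-v1:0",
		os.Getenv("BEDROCK_API_KEY"),
		messages,
		params,
	)
	if err != nil {
		fmt.Println("bedrock error:", err)
		return
	}
	if len(result.Choices) > 0 {
		fmt.Println(result.Choices[0].Message.Content)
	}
}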