Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions interfaces/meta/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package meta

type BedrockMetaConfig struct {
SecretAccessKey *string `json:"secret_access_key,omitempty"`
Region *string `json:"region,omitempty"`
SessionToken *string `json:"session_token,omitempty"`
ARN *string `json:"arn,omitempty"`
InferenceProfiles map[string]string `json:"inference_profiles,omitempty"`
}

func (c *BedrockMetaConfig) GetSecretAccessKey() *string {
return c.SecretAccessKey
}

func (c *BedrockMetaConfig) GetRegion() *string {
return c.Region
}

func (c *BedrockMetaConfig) GetSessionToken() *string {
return c.SessionToken
}

func (c *BedrockMetaConfig) GetARN() *string {
return c.ARN
}

func (c *BedrockMetaConfig) GetInferenceProfiles() map[string]string {
return c.InferenceProfiles
}
14 changes: 7 additions & 7 deletions interfaces/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ type NetworkConfig struct {
RetryBackoffMax time.Duration `json:"retry_backoff_max"`
}

type MetaConfig struct {
SecretAccessKey *string `json:"secret_access_key,omitempty"`
Region *string `json:"region,omitempty"`
SessionToken *string `json:"session_token,omitempty"`
ARN *string `json:"arn,omitempty"`
InferenceProfiles map[string]string `json:"inference_profiles,omitempty"`
type MetaConfig interface {
GetSecretAccessKey() *string
GetRegion() *string
GetSessionToken() *string
GetARN() *string
GetInferenceProfiles() map[string]string
}

type ConcurrencyAndBufferSize struct {
Expand Down Expand Up @@ -44,7 +44,7 @@ type ProxyConfig struct {

type ProviderConfig struct {
NetworkConfig NetworkConfig `json:"network_config"`
MetaConfig *MetaConfig `json:"meta_config,omitempty"`
MetaConfig MetaConfig `json:"meta_config,omitempty"`
ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"`
Logger Logger `json:"logger"`
ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"`
Expand Down
16 changes: 8 additions & 8 deletions providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ type BedrockError struct {

type BedrockProvider struct {
client *http.Client
meta *interfaces.MetaConfig
meta interfaces.MetaConfig
}

func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider {
Expand All @@ -128,8 +128,8 @@ func (provider *BedrockProvider) CompleteRequest(requestBody map[string]interfac
}

region := "us-east-1"
if provider.meta.Region != nil {
region = *provider.meta.Region
if provider.meta.GetRegion() != nil {
region = *provider.meta.GetRegion()
}

jsonBody, err := json.Marshal(requestBody)
Expand All @@ -155,7 +155,7 @@ func (provider *BedrockProvider) CompleteRequest(requestBody map[string]interfac
}
}

if err := SignAWSRequest(req, accessKey, *provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil {
if err := SignAWSRequest(req, accessKey, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); err != nil {
return nil, err
}

Expand Down Expand Up @@ -521,10 +521,10 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []in
// Format the path with proper model identifier
path := fmt.Sprintf("%s/converse", model)

if provider.meta != nil && provider.meta.InferenceProfiles != nil {
if inferenceProfileId, ok := provider.meta.InferenceProfiles[model]; ok {
if provider.meta.ARN != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.ARN, inferenceProfileId))
if provider.meta != nil && provider.meta.GetInferenceProfiles() != nil {
if inferenceProfileId, ok := provider.meta.GetInferenceProfiles()[model]; ok {
if provider.meta.GetARN() != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.GetARN(), inferenceProfileId))
path = fmt.Sprintf("%s/converse", encodedModelIdentifier)
}
}
Expand Down
3 changes: 2 additions & 1 deletion tests/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/maximhq/bifrost/interfaces"
"github.com/maximhq/bifrost/interfaces/meta"

"github.com/maximhq/maxim-go"
)
Expand Down Expand Up @@ -95,7 +96,7 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp
RetryBackoffInitial: 100 * time.Millisecond,
RetryBackoffMax: 2 * time.Second,
},
MetaConfig: &interfaces.MetaConfig{
MetaConfig: &meta.BedrockMetaConfig{
SecretAccessKey: maxim.StrPtr(os.Getenv("BEDROCK_ACCESS_KEY")),
Region: maxim.StrPtr("us-east-1"),
},
Expand Down