diff --git a/interfaces/meta/bedrock.go b/interfaces/meta/bedrock.go new file mode 100644 index 000000000..c4caf56a8 --- /dev/null +++ b/interfaces/meta/bedrock.go @@ -0,0 +1,29 @@ +package meta + +type BedrockMetaConfig struct { + SecretAccessKey *string `json:"secret_access_key,omitempty"` + Region *string `json:"region,omitempty"` + SessionToken *string `json:"session_token,omitempty"` + ARN *string `json:"arn,omitempty"` + InferenceProfiles map[string]string `json:"inference_profiles,omitempty"` +} + +func (c *BedrockMetaConfig) GetSecretAccessKey() *string { + return c.SecretAccessKey +} + +func (c *BedrockMetaConfig) GetRegion() *string { + return c.Region +} + +func (c *BedrockMetaConfig) GetSessionToken() *string { + return c.SessionToken +} + +func (c *BedrockMetaConfig) GetARN() *string { + return c.ARN +} + +func (c *BedrockMetaConfig) GetInferenceProfiles() map[string]string { + return c.InferenceProfiles +} diff --git a/interfaces/provider.go b/interfaces/provider.go index 8550f89fc..20d54c353 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -11,12 +11,12 @@ type NetworkConfig struct { RetryBackoffMax time.Duration `json:"retry_backoff_max"` } -type MetaConfig struct { - SecretAccessKey *string `json:"secret_access_key,omitempty"` - Region *string `json:"region,omitempty"` - SessionToken *string `json:"session_token,omitempty"` - ARN *string `json:"arn,omitempty"` - InferenceProfiles map[string]string `json:"inference_profiles,omitempty"` +type MetaConfig interface { + GetSecretAccessKey() *string + GetRegion() *string + GetSessionToken() *string + GetARN() *string + GetInferenceProfiles() map[string]string } type ConcurrencyAndBufferSize struct { @@ -44,7 +44,7 @@ type ProxyConfig struct { type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` - MetaConfig *MetaConfig `json:"meta_config,omitempty"` + MetaConfig MetaConfig `json:"meta_config,omitempty"` ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` Logger Logger `json:"logger"` ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` diff --git a/providers/bedrock.go b/providers/bedrock.go index 890b4f41b..1c94af694 100644 --- a/providers/bedrock.go +++ b/providers/bedrock.go @@ -103,7 +103,7 @@ type BedrockError struct { type BedrockProvider struct { client *http.Client - meta *interfaces.MetaConfig + meta interfaces.MetaConfig } func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider { @@ -128,8 +128,8 @@ func (provider *BedrockProvider) CompleteRequest(requestBody map[string]interfac } region := "us-east-1" - if provider.meta.Region != nil { - region = *provider.meta.Region + if provider.meta.GetRegion() != nil { + region = *provider.meta.GetRegion() } jsonBody, err := json.Marshal(requestBody) @@ -155,7 +155,7 @@ func (provider *BedrockProvider) CompleteRequest(requestBody map[string]interfac } } - if err := SignAWSRequest(req, accessKey, *provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil { + if err := SignAWSRequest(req, accessKey, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); err != nil { return nil, err } @@ -521,10 +521,10 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []in // Format the path with proper model identifier path := fmt.Sprintf("%s/converse", model) - if provider.meta != nil && provider.meta.InferenceProfiles != nil { - if inferenceProfileId, ok := provider.meta.InferenceProfiles[model]; ok { - if provider.meta.ARN != nil { - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.ARN, inferenceProfileId)) + if provider.meta != nil && provider.meta.GetInferenceProfiles() != nil { + if inferenceProfileId, ok := provider.meta.GetInferenceProfiles()[model]; ok { + if provider.meta.GetARN() != nil { + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.GetARN(), inferenceProfileId)) path = fmt.Sprintf("%s/converse", encodedModelIdentifier) } } diff --git a/tests/account.go b/tests/account.go index f4df02a2a..ce1ecc257 100644 --- a/tests/account.go +++ b/tests/account.go @@ -6,6 +6,7 @@ import ( "time" "github.com/maximhq/bifrost/interfaces" + "github.com/maximhq/bifrost/interfaces/meta" "github.com/maximhq/maxim-go" ) @@ -95,7 +96,7 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp RetryBackoffInitial: 100 * time.Millisecond, RetryBackoffMax: 2 * time.Second, }, - MetaConfig: &interfaces.MetaConfig{ + MetaConfig: &meta.BedrockMetaConfig{ SecretAccessKey: maxim.StrPtr(os.Getenv("BEDROCK_ACCESS_KEY")), Region: maxim.StrPtr("us-east-1"), },