diff --git a/bifrost.go b/bifrost.go index 64cb0f36a..a3e41e767 100644 --- a/bifrost.go +++ b/bifrost.go @@ -32,23 +32,32 @@ type Bifrost struct { account interfaces.Account providers []interfaces.Provider // list of processed providers plugins []interfaces.Plugin + configs map[interfaces.SupportedModelProvider]interfaces.ProviderConfig requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues wg map[interfaces.SupportedModelProvider]*sync.WaitGroup } -func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider) (interfaces.Provider, error) { +func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) { switch providerKey { case interfaces.OpenAI: - return providers.NewOpenAIProvider(), nil + return providers.NewOpenAIProvider(config), nil case interfaces.Anthropic: - return providers.NewAnthropicProvider(), nil + return providers.NewAnthropicProvider(config), nil default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } } -func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelProvider) error { - provider, err := createProviderFromProviderKey(providerKey) +func getConfigForProvider(providerKey interfaces.SupportedModelProvider, configs map[interfaces.SupportedModelProvider]interfaces.ProviderConfig) (*interfaces.ProviderConfig, error) { + if config, ok := configs[providerKey]; ok { + return &config, nil + } + + return nil, fmt.Errorf("no config found for provider: %s", providerKey) +} + +func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) error { + provider, err := createProviderFromProviderKey(providerKey, config) if err != nil { return fmt.Errorf("failed to get provider for the given key: %v", err) } @@ -80,7 +89,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro } // Initializes infinite listening channels for each provider -func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, error) { +func Init(account interfaces.Account, plugins []interfaces.Plugin, configs map[interfaces.SupportedModelProvider]interfaces.ProviderConfig) (*Bifrost, error) { bifrost := &Bifrost{account: account, plugins: plugins} bifrost.wg = make(map[interfaces.SupportedModelProvider]*sync.WaitGroup) @@ -93,7 +102,12 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin) (*Bifrost, er // Create buffered channels for each provider and start workers for _, providerKey := range providerKeys { - if err := bifrost.prepareProvider(providerKey); err != nil { + config, err := getConfigForProvider(providerKey, configs) + if err != nil { + return nil, fmt.Errorf("failed to get config for provider: %v", err) + } + + if err := bifrost.prepareProvider(providerKey, config); err != nil { fmt.Printf("failed to prepare provider: %v", err) } } @@ -200,7 +214,12 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr var exists bool if queue, exists = bifrost.requestQueues[providerKey]; !exists { - if err := bifrost.prepareProvider(providerKey); err != nil { + config, err := getConfigForProvider(providerKey, bifrost.configs) + if err != nil { + return nil, fmt.Errorf("failed to get config for provider: %v", err) + } + + if err := bifrost.prepareProvider(providerKey, config); err != nil { return nil, err } diff --git a/interfaces/provider.go b/interfaces/provider.go index 36d7ac094..2eb6313ee 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -202,6 +202,27 @@ type ImageContent struct { // return nil // } +type NetworkConfig struct { + DefaultRequestTimeoutInSeconds int `json:"defaultRequestTimeoutInSeconds"` +} + +type MetaConfig struct { + BedrockMetaConfig *BedrockMetaConfig `json:"bedrockMetaConfig"` +} + +type ProviderConfig struct { + NetworkConfig NetworkConfig `json:"networkConfig"` + MetaConfig *MetaConfig `json:"metaConfig"` +} + +type BedrockMetaConfig struct { + SecretAccessKey string `json:"secretAccessKey"` + Region *string `json:"region"` + SessionToken *string `json:"sessionToken"` + ARN *string `json:"arn"` + InferenceProfiles map[string]string `json:"inferenceProfiles"` +} + // Provider defines the interface for AI model providers type Provider interface { GetProviderKey() SupportedModelProvider diff --git a/providers/anthropic.go b/providers/anthropic.go index 2ba210787..d181a4d19 100644 --- a/providers/anthropic.go +++ b/providers/anthropic.go @@ -50,10 +50,9 @@ type AnthropicProvider struct { } // NewAnthropicProvider creates a new AnthropicProvider instance -func NewAnthropicProvider() *AnthropicProvider { +func NewAnthropicProvider(config *interfaces.ProviderConfig) *AnthropicProvider { return &AnthropicProvider{ - // @comment let us have this be controllable - client: &http.Client{Timeout: 30 * time.Second}, + client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}, } } diff --git a/providers/openai.go b/providers/openai.go index 2c1099aeb..9bcdef267 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -23,9 +23,9 @@ type OpenAIProvider struct { } // NewOpenAIProvider creates a new OpenAI provider instance -func NewOpenAIProvider() *OpenAIProvider { +func NewOpenAIProvider(config *interfaces.ProviderConfig) *OpenAIProvider { return &OpenAIProvider{ - client: &http.Client{Timeout: time.Second * 30}, + client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}, } } diff --git a/tests/setup.go b/tests/setup.go index d088ad4f4..a31f6300a 100644 --- a/tests/setup.go +++ b/tests/setup.go @@ -44,7 +44,20 @@ func getBifrost() (*bifrost.Bifrost, error) { return nil, err } - bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}) + configs := map[interfaces.SupportedModelProvider]interfaces.ProviderConfig{ + interfaces.OpenAI: { + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + }, + interfaces.Anthropic: { + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + }, + } + + bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}, configs) if err != nil { return nil, err }