diff --git a/go.mod b/go.mod index ef8d638aa..1b4fa96a3 100644 --- a/go.mod +++ b/go.mod @@ -28,4 +28,6 @@ require ( github.com/aws/smithy-go v1.22.2 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/text v0.20.0 // indirect ) diff --git a/go.sum b/go.sum index 06db2183d..0993f024c 100644 --- a/go.sum +++ b/go.sum @@ -38,3 +38,7 @@ github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbR github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= diff --git a/interfaces/provider.go b/interfaces/provider.go index bcbf2f7e1..34d51e725 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -29,11 +29,20 @@ const ( EnvProxy ProxyType = "environment" ) +// ProxyConfig holds proxy configuration +type ProxyConfig struct { + Type ProxyType `json:"type"` // Type of proxy (none, http, socks5, environment) + URL string `json:"url"` // Proxy URL (for http and socks5) + Username string `json:"username"` // Optional username for proxy authentication + Password string `json:"password"` // Optional password for proxy authentication +} + type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` MetaConfig *MetaConfig `json:"meta_config,omitempty"` ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` Logger Logger `json:"logger"` + ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` } // Provider defines the interface for AI model providers diff --git a/providers/anthropic.go b/providers/anthropic.go index 15a58d783..1d41f45d3 100644 --- a/providers/anthropic.go +++ b/providers/anthropic.go @@ -65,13 +65,18 @@ type AnthropicProvider struct { // NewAnthropicProvider creates a new AnthropicProvider instance func NewAnthropicProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *AnthropicProvider { + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + return &AnthropicProvider{ logger: logger, - client: &fasthttp.Client{ - ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, - }, + client: client, } } diff --git a/providers/openai.go b/providers/openai.go index 89ff5de6b..eb3bc15c8 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -18,7 +18,6 @@ var ( ErrOpenAIDecodeStructured = fmt.Errorf("error decoding OpenAI structured response") ErrOpenAIDecodeRaw = fmt.Errorf("error decoding OpenAI raw response") ErrOpenAIDecompress = fmt.Errorf("error decompressing OpenAI response") - ErrOpenAIProxyConfig = fmt.Errorf("invalid proxy configuration") ) // OpenAIResponsePool provides a pool for OpenAI response objects @@ -107,6 +106,9 @@ func NewOpenAIProvider(config *interfaces.ProviderConfig, logger interfaces.Logg bifrostResponsePool.Put(&interfaces.BifrostResponse{}) } + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + return &OpenAIProvider{ logger: logger, client: client, diff --git a/providers/utils.go b/providers/utils.go index b8ca0ec13..7b1323c0c 100644 --- a/providers/utils.go +++ b/providers/utils.go @@ -5,13 +5,17 @@ import ( "context" "crypto/sha256" "encoding/hex" + "fmt" "io" "net/http" + "net/url" "reflect" "strings" "time" "github.com/maximhq/bifrost/interfaces" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpproxy" "maps" @@ -160,3 +164,54 @@ func SignAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken return nil } + +// configureProxy sets up the proxy for the fasthttp client +func configureProxy(client *fasthttp.Client, proxyConfig *interfaces.ProxyConfig, logger interfaces.Logger) *fasthttp.Client { + if proxyConfig == nil { + return client + } + + var dialFunc fasthttp.DialFunc + + // Create the appropriate proxy based on type + switch proxyConfig.Type { + case interfaces.NoProxy: + return client + case interfaces.HttpProxy: + if proxyConfig.URL == "" { + logger.Warn("Warning: HTTP proxy URL is required for setting up proxy") + return client + } + dialFunc = fasthttpproxy.FasthttpHTTPDialer(proxyConfig.URL) + case interfaces.Socks5Proxy: + if proxyConfig.URL == "" { + logger.Warn("Warning: SOCKS5 proxy URL is required for setting up proxy") + return client + } + proxyUrl := proxyConfig.URL + // Add authentication if provided + if proxyConfig.Username != "" && proxyConfig.Password != "" { + parsedURL, err := url.Parse(proxyConfig.URL) + if err != nil { + logger.Warn("Invalid proxy configuration: invalid SOCKS5 proxy URL") + return client + } + // Set user and password in the parsed URL + parsedURL.User = url.UserPassword(proxyConfig.Username, proxyConfig.Password) + proxyUrl = parsedURL.String() + } + dialFunc = fasthttpproxy.FasthttpSocksDialer(proxyUrl) + case interfaces.EnvProxy: + // Use environment variables for proxy configuration + dialFunc = fasthttpproxy.FasthttpProxyHTTPDialer() + default: + logger.Warn(fmt.Sprintf("Invalid proxy configuration: unsupported proxy type: %s", proxyConfig.Type)) + return client + } + + if dialFunc != nil { + client.Dial = dialFunc + } + + return client +}