diff --git a/core/go.mod b/core/go.mod index e264aa050..85dd3be3b 100644 --- a/core/go.mod +++ b/core/go.mod @@ -3,6 +3,8 @@ module github.com/maximhq/bifrost/core go 1.25.5 require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 github.com/aws/aws-sdk-go-v2 v1.41.0 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 github.com/aws/aws-sdk-go-v2/config v1.32.5 @@ -22,6 +24,8 @@ require ( require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect @@ -42,12 +46,15 @@ require ( github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.1 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/mailru/easyjson v0.9.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/spf13/cast v1.10.0 // indirect @@ -56,8 +63,8 @@ require ( github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/arch v0.22.0 // indirect + golang.org/x/crypto v0.45.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.38.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/core/go.sum b/core/go.sum index b2daa6f7e..d2526513c 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,5 +1,15 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -60,6 +70,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8Yc github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -69,17 +81,17 @@ github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRS github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.41.1 h1:w78eWfiQam2i8ICL7AL0WFiq7KHNJQ6UB53ZVtH4KGA= @@ -91,6 +103,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= @@ -124,12 +138,14 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index a39ba3dba..27fa7c419 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -219,6 +219,9 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context "o1": "o1", "text-embedding-ada-002": "text-embedding-ada-002", }, + ClientID: bifrost.Ptr(os.Getenv("AZURE_CLIENT_ID")), + ClientSecret: bifrost.Ptr(os.Getenv("AZURE_CLIENT_SECRET")), + TenantID: bifrost.Ptr(os.Getenv("AZURE_TENANT_ID")), }, UseForBatchAPI: bifrost.Ptr(true), }, diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index e332767a1..95504a5bd 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -11,8 +11,12 @@ import ( "mime/multipart" "net/http" "net/url" + "sync" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/openai" @@ -25,13 +29,88 @@ import ( // AzureAuthorizationTokenKey is the context key for the Azure authentication token. const AzureAuthorizationTokenKey schemas.BifrostContextKey = "azure-authorization-token" +// DefaultAzureScope is the default scope for Azure authentication. +const DefaultAzureScope = "https://cognitiveservices.azure.com/.default" + // AzureProvider implements the Provider interface for Azure's API. type AzureProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests - networkConfig schemas.NetworkConfig // Network configuration including extra headers - sendBackRawRequest bool // Whether to include raw request in BifrostResponse - sendBackRawResponse bool // Whether to include raw response in BifrostResponse + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers + + credentials sync.Map // map of tenant ID:client ID to azcore.TokenCredential + sendBackRawRequest bool // Whether to include raw request in BifrostResponse + sendBackRawResponse bool // Whether to include raw response in BifrostResponse +} + +func (p *AzureProvider) getOrCreateAuth( + tenantID, clientID, clientSecret string, +) (azcore.TokenCredential, error) { + key := tenantID + ":" + clientID + + // Fast path + if val, ok := p.credentials.Load(key); ok { + return val.(azcore.TokenCredential), nil + } + + // Slow path - create new credential + cred, err := azidentity.NewClientSecretCredential( + tenantID, + clientID, + clientSecret, + nil, + ) + if err != nil { + return nil, err + } + + actual, _ := p.credentials.LoadOrStore(key, cred) + return actual.(azcore.TokenCredential), nil +} + +// getAzureAuthHeaders returns authentication headers based on priority: +// 1. Service Principal (client ID/secret/tenant ID) - Bearer token +// 2. Context token - Bearer token +// 3. API key - api-key or x-api-key header +func (provider *AzureProvider) getAzureAuthHeaders(ctx context.Context, key schemas.Key, isAnthropicModel bool) (map[string]string, *schemas.BifrostError) { + authHeader := make(map[string]string) + + // Service Principal authentication + if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil && + key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil { + cred, err := provider.getOrCreateAuth(*key.AzureKeyConfig.TenantID, *key.AzureKeyConfig.ClientID, *key.AzureKeyConfig.ClientSecret) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + } + + token, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{DefaultAzureScope}, + }) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + } + + if token.Token == "" { + return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"), schemas.Azure) + } + + authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) + return authHeader, nil + } + + // Context token authentication + if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok && authToken != "" { + authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) + return authHeader, nil + } + + // API key authentication + if isAnthropicModel { + authHeader["x-api-key"] = key.Value + } else { + authHeader["api-key"] = key.Value + } + return authHeader, nil } // NewAzureProvider creates a new Azure provider instance. @@ -89,19 +168,23 @@ func (provider *AzureProvider) completeRequest( req.Header.SetContentType("application/json") var url string - if schemas.IsAnthropicModel(deployment) { - req.Header.Set("x-api-key", key.Value) + isAnthropicModel := schemas.IsAnthropicModel(deployment) + + // Get authentication headers + authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel) + if bifrostErr != nil { + return nil, deployment, 0, bifrostErr + } + + // Apply headers to request + for k, v := range authHeaders { + req.Header.Set(k, v) + } + + if isAnthropicModel { req.Header.Set("anthropic-version", AzureAnthropicAPIVersionDefault) url = fmt.Sprintf("%s/%s", key.AzureKeyConfig.Endpoint, path) } else { - if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - // TODO: Shift this to key.Value like in bedrock and vertex - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) - // Ensure api-key is not accidentally present (from extra headers, etc.) - req.Header.Del("api-key") - } else { - req.Header.Set("api-key", key.Value) - } apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.Ptr(AzureAPIVersionDefault) @@ -170,13 +253,13 @@ func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas. req.Header.SetMethod(http.MethodGet) req.Header.SetContentType("application/json") - // Set Azure authentication - either Bearer token or api-key - if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) - // Ensure api-key is not accidentally present (from extra headers, etc.) - req.Header.Del("api-key") - } else { - req.Header.Set("api-key", key.Value) + // Set Azure authentication + authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false) + if bifrostErr != nil { + return nil, bifrostErr + } + for k, v := range authHeaders { + req.Header.Set(k, v) } // Send the request and measure latency @@ -322,14 +405,10 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion) - // Prepare Azure-specific headers - authHeader := make(map[string]string) - - // Set Azure authentication - either Bearer token or api-key - if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) - } else { - authHeader["api-key"] = key.Value + // Get Azure authentication headers + authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) + if err != nil { + return nil, err } customPostResponseConverter := func(response *schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse { @@ -465,10 +544,12 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo return response } - authHeader := make(map[string]string) var url string if schemas.IsAnthropicModel(deployment) { - authHeader["x-api-key"] = key.Value + authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) + if err != nil { + return nil, err + } authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint) @@ -512,11 +593,9 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo }, ) } else { - // Set Azure authentication - either Bearer token or api-key - if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) - } else { - authHeader["api-key"] = key.Value + authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) + if err != nil { + return nil, err } apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -652,10 +731,12 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn return response } - authHeader := make(map[string]string) var url string if schemas.IsAnthropicModel(deployment) { - authHeader["x-api-key"] = key.Value + authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) + if err != nil { + return nil, err + } authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint) @@ -685,11 +766,9 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn }, ) } else { - // Set Azure authentication - either Bearer token or api-key - if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) - } else { - authHeader["api-key"] = key.Value + authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) + if err != nil { + return nil, err } url = fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint) @@ -832,19 +911,17 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner return nil, err } - authHeader := make(map[string]string) - var url string - // Set Azure authentication - either Bearer token or api-key - if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) - } else { - authHeader["api-key"] = key.Value + // Get Azure authentication headers + authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) + if err != nil { + return nil, err } + apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.Ptr(AzureAPIVersionDefault) } - url = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion) // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1228,7 +1305,9 @@ func (provider *AzureProvider) FileUpload(ctx context.Context, key schemas.Key, req.Header.SetContentType(writer.FormDataContentType()) // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if err := provider.setAzureAuth(ctx, req, key); err != nil { + return nil, err + } req.SetBody(buf.Bytes()) @@ -1333,7 +1412,9 @@ func (provider *AzureProvider) FileList(ctx context.Context, keys []schemas.Key, req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if err := provider.setAzureAuth(ctx, req, key); err != nil { + return nil, err + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) @@ -1437,7 +1518,12 @@ func (provider *AzureProvider) FileRetrieve(ctx context.Context, keys []schemas. req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + lastErr = authErr + continue + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) @@ -1528,7 +1614,12 @@ func (provider *AzureProvider) FileDelete(ctx context.Context, keys []schemas.Ke req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + lastErr = authErr + continue + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) @@ -1650,7 +1741,12 @@ func (provider *AzureProvider) FileContent(ctx context.Context, keys []schemas.K req.Header.SetMethod(http.MethodGet) // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + lastErr = authErr + continue + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) @@ -1766,7 +1862,9 @@ func (provider *AzureProvider) BatchCreate(ctx context.Context, key schemas.Key, req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if err := provider.setAzureAuth(ctx, req, key); err != nil { + return nil, err + } // Build request body openAIReq := &openai.OpenAIBatchRequest{ @@ -1885,7 +1983,9 @@ func (provider *AzureProvider) BatchList(ctx context.Context, keys []schemas.Key req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if err := provider.setAzureAuth(ctx, req, key); err != nil { + return nil, err + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) @@ -1984,7 +2084,12 @@ func (provider *AzureProvider) BatchRetrieve(ctx context.Context, keys []schemas req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + lastErr = authErr + continue + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) @@ -2077,7 +2182,12 @@ func (provider *AzureProvider) BatchCancel(ctx context.Context, keys []schemas.K req.Header.SetContentType("application/json") // Set Azure authentication - provider.setAzureAuth(ctx, req, key) + if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + lastErr = authErr + continue + } // Make request latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) diff --git a/core/providers/azure/files.go b/core/providers/azure/files.go index 88c03e218..d059f8d1b 100644 --- a/core/providers/azure/files.go +++ b/core/providers/azure/files.go @@ -5,20 +5,54 @@ import ( "fmt" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/maximhq/bifrost/core/providers/openai" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// setAzureAuth sets the Azure authentication header on the request. -func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.Request, key schemas.Key) { +// setAzureAuth sets the Azure authentication header on the request for OpenAI models. +// It handles three authentication methods in order of priority: +// 1. Service Principal (client ID/secret/tenant ID) - uses Bearer token +// 2. Context token - uses Bearer token +// 3. API key - uses api-key header +func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.Request, key schemas.Key) *schemas.BifrostError { + // Service Principal authentication + if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil && + key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil { + cred, err := provider.getOrCreateAuth(*key.AzureKeyConfig.TenantID, *key.AzureKeyConfig.ClientID, *key.AzureKeyConfig.ClientSecret) + if err != nil { + return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + } + + token, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{DefaultAzureScope}, + }) + if err != nil { + return providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + } + + if token.Token == "" { + return providerUtils.NewBifrostOperationError("Azure access token is empty", fmt.Errorf("token is empty"), schemas.Azure) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) + req.Header.Del("api-key") + return nil + } + + // Context token authentication if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok && authToken != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) req.Header.Del("api-key") - } else { - req.Header.Del("Authorization") - req.Header.Set("api-key", key.Value) + return nil } + + // API key authentication + req.Header.Del("Authorization") + req.Header.Set("api-key", key.Value) + return nil } // AzureFileResponse represents an Azure file response (same as OpenAI). diff --git a/core/schemas/account.go b/core/schemas/account.go index 808a7586b..788efa277 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -24,6 +24,10 @@ type AzureKeyConfig struct { Endpoint string `json:"endpoint"` // Azure service endpoint URL Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" + + ClientID *string `json:"client_id,omitempty"` // Azure client ID for authentication + ClientSecret *string `json:"client_secret,omitempty"` // Azure client secret for authentication + TenantID *string `json:"tenant_id,omitempty"` // Azure tenant ID for authentication } // VertexKeyConfig represents the Vertex-specific configuration. diff --git a/docs/quickstart/go-sdk/provider-configuration.mdx b/docs/quickstart/go-sdk/provider-configuration.mdx index b3a281126..8945c8008 100644 --- a/docs/quickstart/go-sdk/provider-configuration.mdx +++ b/docs/quickstart/go-sdk/provider-configuration.mdx @@ -346,7 +346,42 @@ Enterprise cloud providers require additional configuration beyond API keys. Con -Azure requires endpoint URLs, deployment mappings, and API version configuration: +Azure supports two authentication methods: + +**Azure Entra ID (Service Principal)** + +Production-ready authentication with automatic token management. When `ClientID`, `ClientSecret`, and `TenantID` are set, Bifrost handles token acquisition using the default scope `https://cognitiveservices.azure.com/.default`. + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: "", // Leave empty for Service Principal auth + Models: []string{"gpt-4o", "gpt-4o-mini"}, + Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + ClientID: bifrost.Ptr(os.Getenv("AZURE_CLIENT_ID")), + ClientSecret: bifrost.Ptr(os.Getenv("AZURE_CLIENT_SECRET")), + TenantID: bifrost.Ptr(os.Getenv("AZURE_TENANT_ID")), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment", + }, + APIVersion: bifrost.Ptr("2024-08-01-preview"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + +**Direct Authentication** + +For simpler use cases, provide the authentication credential directly in the `Value` field: ```go func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { @@ -354,16 +389,16 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo case schemas.Azure: return []schemas.Key{ { - Value: os.Getenv("AZURE_API_KEY"), + Value: os.Getenv("AZURE_OPENAI_KEY"), Models: []string{"gpt-4o", "gpt-4o-mini"}, Weight: 1.0, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), // e.g., "https://your-resource.openai.azure.com" + Endpoint: os.Getenv("AZURE_ENDPOINT"), Deployments: map[string]string{ "gpt-4o": "gpt-4o-deployment", "gpt-4o-mini": "gpt-4o-mini-deployment", }, - APIVersion: bifrost.Ptr("2024-08-01-preview"), // Azure API version + APIVersion: bifrost.Ptr("2024-08-01-preview"), }, }, }, nil @@ -372,6 +407,10 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo } ``` + +If `ClientID`, `ClientSecret`, and `TenantID` are configured, Service Principal authentication is used. Otherwise, direct authentication with the `Value` field is used. + + diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 93975ff16..8fe31397c 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -110,6 +110,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddUseForBatchAPIColumnAndS3BucketsConfig(ctx, db); err != nil { return err } + if err := migrationAddAzureClientIDAndClientSecretAndTenantIDColumns(ctx, db); err != nil { + return err + } return nil } @@ -1827,3 +1830,45 @@ func migrationAddUseForBatchAPIColumnAndS3BucketsConfig(ctx context.Context, db } return nil } + +// migrationAddAzureClientIDAndClientSecretAndTenantIDColumns adds the azure_client_id, azure_client_secret, and azure_tenant_id columns to the key table +func migrationAddAzureClientIDAndClientSecretAndTenantIDColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_azure_client_id_and_client_secret_and_tenant_id_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableKey{}, "azure_client_id") { + if err := migrator.AddColumn(&tables.TableKey{}, "azure_client_id"); err != nil { + return fmt.Errorf("failed to add azure_client_id column: %w", err) + } + } + if !migrator.HasColumn(&tables.TableKey{}, "azure_client_secret") { + if err := migrator.AddColumn(&tables.TableKey{}, "azure_client_secret"); err != nil { + return fmt.Errorf("failed to add azure_client_secret column: %w", err) + } + } + if !migrator.HasColumn(&tables.TableKey{}, "azure_tenant_id") { + if err := migrator.AddColumn(&tables.TableKey{}, "azure_tenant_id"); err != nil { + return fmt.Errorf("failed to add azure_tenant_id column: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableKey{}, "azure_client_id"); err != nil { + return fmt.Errorf("failed to drop azure_client_id column: %w", err) + } + if err := migrator.DropColumn(&tables.TableKey{}, "azure_client_secret"); err != nil { + return fmt.Errorf("failed to drop azure_client_secret column: %w", err) + } + if err := migrator.DropColumn(&tables.TableKey{}, "azure_tenant_id"); err != nil { + return fmt.Errorf("failed to drop azure_tenant_id column: %w", err) + } + return nil + }, + }}) + return m.Migrate() +} diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index ede86a175..9ac3643cc 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -30,6 +30,9 @@ type TableKey struct { AzureEndpoint *string `gorm:"type:text" json:"azure_endpoint,omitempty"` AzureAPIVersion *string `gorm:"type:varchar(50)" json:"azure_api_version,omitempty"` AzureDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + AzureClientID *string `gorm:"type:varchar(255)" json:"azure_client_id,omitempty"` + AzureClientSecret *string `gorm:"type:text" json:"azure_client_secret,omitempty"` + AzureTenantID *string `gorm:"type:varchar(255)" json:"azure_tenant_id,omitempty"` // Vertex config fields (embedded) VertexProjectID *string `gorm:"type:varchar(255)" json:"vertex_project_id,omitempty"` @@ -39,13 +42,13 @@ type TableKey struct { VertexDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Bedrock config fields (embedded) - BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"` - BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"` - BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"` - BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"` - BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"` - BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config + BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"` + BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"` + BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"` + BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"` + BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"` + BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config // Batch API configuration UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations @@ -84,6 +87,9 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.AzureEndpoint = nil } k.AzureAPIVersion = k.AzureKeyConfig.APIVersion + k.AzureClientID = k.AzureKeyConfig.ClientID + k.AzureClientSecret = k.AzureKeyConfig.ClientSecret + k.AzureTenantID = k.AzureKeyConfig.TenantID if k.AzureKeyConfig.Deployments != nil { data, err := json.Marshal(k.AzureKeyConfig.Deployments) if err != nil { @@ -98,6 +104,9 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.AzureEndpoint = nil k.AzureAPIVersion = nil k.AzureDeploymentsJSON = nil + k.AzureClientID = nil + k.AzureClientSecret = nil + k.AzureTenantID = nil } if k.VertexKeyConfig != nil { @@ -202,8 +211,11 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { // Reconstruct Azure config if fields are present if k.AzureEndpoint != nil { azureConfig := &schemas.AzureKeyConfig{ - Endpoint: "", - APIVersion: k.AzureAPIVersion, + Endpoint: "", + APIVersion: k.AzureAPIVersion, + ClientID: k.AzureClientID, + ClientSecret: k.AzureClientSecret, + TenantID: k.AzureTenantID, } if k.AzureEndpoint != nil { diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go index ddb7b1fe8..e3d4d438b 100644 --- a/framework/configstore/tables/virtualkey.go +++ b/framework/configstore/tables/virtualkey.go @@ -53,6 +53,9 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { // Clear all Azure-related sensitive fields key.AzureEndpoint = nil key.AzureAPIVersion = nil + key.AzureClientID = nil + key.AzureClientSecret = nil + key.AzureTenantID = nil key.AzureDeploymentsJSON = nil key.AzureKeyConfig = nil diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 0ebcfe957..63811098d 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -764,6 +764,31 @@ func (h *ProviderHandler) mergeKeys(provider schemas.ModelProvider, oldRawKeys [ mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion } } + // handle client id and secret and tenant id + if updateKey.AzureKeyConfig.ClientID != nil && + oldRedactedKey.AzureKeyConfig.ClientID != nil && + oldRawKey.AzureKeyConfig != nil { + if lib.IsRedacted(*updateKey.AzureKeyConfig.ClientID) && + strings.EqualFold(*updateKey.AzureKeyConfig.ClientID, *oldRedactedKey.AzureKeyConfig.ClientID) { + mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID + } + } + if updateKey.AzureKeyConfig.ClientSecret != nil && + oldRedactedKey.AzureKeyConfig.ClientSecret != nil && + oldRawKey.AzureKeyConfig != nil { + if lib.IsRedacted(*updateKey.AzureKeyConfig.ClientSecret) && + strings.EqualFold(*updateKey.AzureKeyConfig.ClientSecret, *oldRedactedKey.AzureKeyConfig.ClientSecret) { + mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret + } + } + if updateKey.AzureKeyConfig.TenantID != nil && + oldRedactedKey.AzureKeyConfig.TenantID != nil && + oldRawKey.AzureKeyConfig != nil { + if lib.IsRedacted(*updateKey.AzureKeyConfig.TenantID) && + strings.EqualFold(*updateKey.AzureKeyConfig.TenantID, *oldRedactedKey.AzureKeyConfig.TenantID) { + mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID + } + } } // Handle Vertex config redacted values diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index b5a0fff8b..d9f786a5c 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -2166,7 +2166,34 @@ func (c *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*con azureConfig.APIVersion = key.AzureKeyConfig.APIVersion } } + // Redact ClientID if present + if key.AzureKeyConfig.ClientID != nil { + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_id", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.ClientID = bifrost.Ptr("env." + envVar) + } else { + azureConfig.ClientID = bifrost.Ptr(RedactKey(*key.AzureKeyConfig.ClientID)) + } + } + // Redact ClientSecret if present + if key.AzureKeyConfig.ClientSecret != nil { + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_secret", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.ClientSecret = bifrost.Ptr("env." + envVar) + } else if !strings.HasPrefix(*key.AzureKeyConfig.ClientSecret, "env.") { + azureConfig.ClientSecret = bifrost.Ptr(RedactKey(*key.AzureKeyConfig.ClientSecret)) + } + } + // Redact TenantID if present + if key.AzureKeyConfig.TenantID != nil { + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.tenant_id", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.TenantID = bifrost.Ptr("env." + envVar) + } else { + azureConfig.TenantID = bifrost.Ptr(RedactKey(*key.AzureKeyConfig.TenantID)) + } + } redactedConfig.Keys[i].AzureKeyConfig = azureConfig } @@ -3264,6 +3291,18 @@ func (c *Config) getFieldValue(key schemas.Key, fieldName string) string { if key.AzureKeyConfig != nil && key.AzureKeyConfig.APIVersion != nil { return *key.AzureKeyConfig.APIVersion } + case "client_id": + if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil { + return *key.AzureKeyConfig.ClientID + } + case "client_secret": + if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientSecret != nil { + return *key.AzureKeyConfig.ClientSecret + } + case "tenant_id": + if key.AzureKeyConfig != nil && key.AzureKeyConfig.TenantID != nil { + return *key.AzureKeyConfig.TenantID + } case "access_key": if key.BedrockKeyConfig != nil { return key.BedrockKeyConfig.AccessKey @@ -3397,6 +3436,63 @@ func (c *Config) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas azureConfig.APIVersion = &processedAPIVersion } + // Process ClientID if present + if azureConfig.ClientID != nil { + processedClientID, envVar, err := c.processEnvValue(*azureConfig.ClientID) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + c.EnvKeys[envVar] = append(c.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_id", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.ClientID = &processedClientID + } + + // Process ClientSecret if present + if azureConfig.ClientSecret != nil { + processedClientSecret, envVar, err := c.processEnvValue(*azureConfig.ClientSecret) + if err != nil { + return err + } + azureConfig.ClientSecret = &processedClientSecret + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + c.EnvKeys[envVar] = append(c.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_secret", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.ClientSecret = &processedClientSecret + } + + // Process TenantID if present + if azureConfig.TenantID != nil { + processedTenantID, envVar, err := c.processEnvValue(*azureConfig.TenantID) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + c.EnvKeys[envVar] = append(c.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: provider, + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.tenant_id", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.TenantID = &processedTenantID + } return nil } diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx index db5253bce..1dcf03f4d 100644 --- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx @@ -206,6 +206,54 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { )} /> + + + + + Azure Entra ID Authentication + + To use Azure Entra ID authentication, fill in Client ID, Client Secret, and Tenant ID. Please leave API Key empty when using Entra ID authentication. + + + ( + + Client ID (Optional) + + + + + + )} + /> + ( + + Client Secret (Optional) + + + + + + )} + /> + ( + + Tenant ID (Optional) + + + + + + )} + /> !value || isValidDeployments(value), { message: "Valid Deployments (JSON object) are required for Azure keys" }), api_version: z.string().optional(), + client_id: z.string().optional(), + client_secret: z.string().optional(), + tenant_id: z.string().optional(), }); const VertexKeyConfigSchema = z.object({ diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 9d616464a..25053cf32 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -24,12 +24,18 @@ export interface AzureKeyConfig { endpoint: string; deployments?: Record | string; // Allow string during editing api_version?: string; + client_id?: string; + client_secret?: string; + tenant_id?: string; } export const DefaultAzureKeyConfig: AzureKeyConfig = { endpoint: "", deployments: {}, api_version: "2024-02-01", + client_id: "", + client_secret: "", + tenant_id: "", } as const satisfies Required; // VertexKeyConfig matching Go's schemas.VertexKeyConfig diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 60e9644fe..d44d59e47 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -18,6 +18,9 @@ export const azureKeyConfigSchema = z endpoint: z.url("Must be a valid URL"), deployments: z.union([z.record(z.string(), z.string()), z.string()]).optional(), api_version: z.string().optional(), + client_id: z.string().optional(), + client_secret: z.string().optional(), + tenant_id: z.string().optional(), }) .refine( (data) => {