diff --git a/core/go.mod b/core/go.mod
index e264aa050..85dd3be3b 100644
--- a/core/go.mod
+++ b/core/go.mod
@@ -3,6 +3,8 @@ module github.com/maximhq/bifrost/core
go 1.25.5
require (
+ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
+ github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1
github.com/aws/aws-sdk-go-v2 v1.41.0
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4
github.com/aws/aws-sdk-go-v2/config v1.32.5
@@ -22,6 +24,8 @@ require (
require (
cloud.google.com/go/compute/metadata v0.9.0 // indirect
+ github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
+ github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect
@@ -42,12 +46,15 @@ require (
github.com/bytedance/sonic/loader v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
+ github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
+ github.com/kylelemons/godebug v1.1.0 // indirect
github.com/mailru/easyjson v0.9.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/spf13/cast v1.10.0 // indirect
@@ -56,8 +63,8 @@ require (
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/arch v0.22.0 // indirect
+ golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
- gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/core/go.sum b/core/go.sum
index b2daa6f7e..d2526513c 100644
--- a/core/go.sum
+++ b/core/go.sum
@@ -1,5 +1,15 @@
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY=
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
+github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM=
+github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
+github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4=
@@ -60,6 +70,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8Yc
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
+github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
+github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -69,17 +81,17 @@ github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRS
github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
+github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
-github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mark3labs/mcp-go v0.41.1 h1:w78eWfiQam2i8ICL7AL0WFiq7KHNJQ6UB53ZVtH4KGA=
@@ -91,6 +103,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
@@ -124,12 +138,14 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
+golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go
index a39ba3dba..27fa7c419 100644
--- a/core/internal/testutil/account.go
+++ b/core/internal/testutil/account.go
@@ -219,6 +219,9 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context
"o1": "o1",
"text-embedding-ada-002": "text-embedding-ada-002",
},
+ ClientID: bifrost.Ptr(os.Getenv("AZURE_CLIENT_ID")),
+ ClientSecret: bifrost.Ptr(os.Getenv("AZURE_CLIENT_SECRET")),
+ TenantID: bifrost.Ptr(os.Getenv("AZURE_TENANT_ID")),
},
UseForBatchAPI: bifrost.Ptr(true),
},
diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go
index e332767a1..95504a5bd 100644
--- a/core/providers/azure/azure.go
+++ b/core/providers/azure/azure.go
@@ -11,8 +11,12 @@ import (
"mime/multipart"
"net/http"
"net/url"
+ "sync"
"time"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
+ "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/openai"
@@ -25,13 +29,88 @@ import (
// AzureAuthorizationTokenKey is the context key for the Azure authentication token.
const AzureAuthorizationTokenKey schemas.BifrostContextKey = "azure-authorization-token"
+// DefaultAzureScope is the default scope for Azure authentication.
+const DefaultAzureScope = "https://cognitiveservices.azure.com/.default"
+
// AzureProvider implements the Provider interface for Azure's API.
type AzureProvider struct {
- logger schemas.Logger // Logger for provider operations
- client *fasthttp.Client // HTTP client for API requests
- networkConfig schemas.NetworkConfig // Network configuration including extra headers
- sendBackRawRequest bool // Whether to include raw request in BifrostResponse
- sendBackRawResponse bool // Whether to include raw response in BifrostResponse
+ logger schemas.Logger // Logger for provider operations
+ client *fasthttp.Client // HTTP client for API requests
+ networkConfig schemas.NetworkConfig // Network configuration including extra headers
+
+ credentials sync.Map // map of tenant ID:client ID to azcore.TokenCredential
+ sendBackRawRequest bool // Whether to include raw request in BifrostResponse
+ sendBackRawResponse bool // Whether to include raw response in BifrostResponse
+}
+
+func (p *AzureProvider) getOrCreateAuth(
+ tenantID, clientID, clientSecret string,
+) (azcore.TokenCredential, error) {
+ key := tenantID + ":" + clientID
+
+ // Fast path
+ if val, ok := p.credentials.Load(key); ok {
+ return val.(azcore.TokenCredential), nil
+ }
+
+ // Slow path - create new credential
+ cred, err := azidentity.NewClientSecretCredential(
+ tenantID,
+ clientID,
+ clientSecret,
+ nil,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ actual, _ := p.credentials.LoadOrStore(key, cred)
+ return actual.(azcore.TokenCredential), nil
+}
+
+// getAzureAuthHeaders returns authentication headers based on priority:
+// 1. Service Principal (client ID/secret/tenant ID) - Bearer token
+// 2. Context token - Bearer token
+// 3. API key - api-key or x-api-key header
+func (provider *AzureProvider) getAzureAuthHeaders(ctx context.Context, key schemas.Key, isAnthropicModel bool) (map[string]string, *schemas.BifrostError) {
+ authHeader := make(map[string]string)
+
+ // Service Principal authentication
+ if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil &&
+ key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil {
+ cred, err := provider.getOrCreateAuth(*key.AzureKeyConfig.TenantID, *key.AzureKeyConfig.ClientID, *key.AzureKeyConfig.ClientSecret)
+ if err != nil {
+ return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure)
+ }
+
+ token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
+ Scopes: []string{DefaultAzureScope},
+ })
+ if err != nil {
+ return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure)
+ }
+
+ if token.Token == "" {
+ return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"), schemas.Azure)
+ }
+
+ authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token)
+ return authHeader, nil
+ }
+
+ // Context token authentication
+ if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok && authToken != "" {
+ authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken)
+ return authHeader, nil
+ }
+
+ // API key authentication
+ if isAnthropicModel {
+ authHeader["x-api-key"] = key.Value
+ } else {
+ authHeader["api-key"] = key.Value
+ }
+ return authHeader, nil
}
// NewAzureProvider creates a new Azure provider instance.
@@ -89,19 +168,23 @@ func (provider *AzureProvider) completeRequest(
req.Header.SetContentType("application/json")
var url string
- if schemas.IsAnthropicModel(deployment) {
- req.Header.Set("x-api-key", key.Value)
+ isAnthropicModel := schemas.IsAnthropicModel(deployment)
+
+ // Get authentication headers
+ authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel)
+ if bifrostErr != nil {
+ return nil, deployment, 0, bifrostErr
+ }
+
+ // Apply headers to request
+ for k, v := range authHeaders {
+ req.Header.Set(k, v)
+ }
+
+ if isAnthropicModel {
req.Header.Set("anthropic-version", AzureAnthropicAPIVersionDefault)
url = fmt.Sprintf("%s/%s", key.AzureKeyConfig.Endpoint, path)
} else {
- if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
- // TODO: Shift this to key.Value like in bedrock and vertex
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
- // Ensure api-key is not accidentally present (from extra headers, etc.)
- req.Header.Del("api-key")
- } else {
- req.Header.Set("api-key", key.Value)
- }
apiVersion := key.AzureKeyConfig.APIVersion
if apiVersion == nil {
apiVersion = schemas.Ptr(AzureAPIVersionDefault)
@@ -170,13 +253,13 @@ func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas.
req.Header.SetMethod(http.MethodGet)
req.Header.SetContentType("application/json")
- // Set Azure authentication - either Bearer token or api-key
- if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
- // Ensure api-key is not accidentally present (from extra headers, etc.)
- req.Header.Del("api-key")
- } else {
- req.Header.Set("api-key", key.Value)
+ // Set Azure authentication
+ authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false)
+ if bifrostErr != nil {
+ return nil, bifrostErr
+ }
+ for k, v := range authHeaders {
+ req.Header.Set(k, v)
}
// Send the request and measure latency
@@ -322,14 +405,10 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo
url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion)
- // Prepare Azure-specific headers
- authHeader := make(map[string]string)
-
- // Set Azure authentication - either Bearer token or api-key
- if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
- authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken)
- } else {
- authHeader["api-key"] = key.Value
+ // Get Azure authentication headers
+ authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
+ if err != nil {
+ return nil, err
}
customPostResponseConverter := func(response *schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse {
@@ -465,10 +544,12 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
return response
}
- authHeader := make(map[string]string)
var url string
if schemas.IsAnthropicModel(deployment) {
- authHeader["x-api-key"] = key.Value
+ authHeader, err := provider.getAzureAuthHeaders(ctx, key, true)
+ if err != nil {
+ return nil, err
+ }
authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault
url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint)
@@ -512,11 +593,9 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
},
)
} else {
- // Set Azure authentication - either Bearer token or api-key
- if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
- authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken)
- } else {
- authHeader["api-key"] = key.Value
+ authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
+ if err != nil {
+ return nil, err
}
apiVersion := key.AzureKeyConfig.APIVersion
if apiVersion == nil {
@@ -652,10 +731,12 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn
return response
}
- authHeader := make(map[string]string)
var url string
if schemas.IsAnthropicModel(deployment) {
- authHeader["x-api-key"] = key.Value
+ authHeader, err := provider.getAzureAuthHeaders(ctx, key, true)
+ if err != nil {
+ return nil, err
+ }
authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault
url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint)
@@ -685,11 +766,9 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn
},
)
} else {
- // Set Azure authentication - either Bearer token or api-key
- if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
- authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken)
- } else {
- authHeader["api-key"] = key.Value
+ authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
+ if err != nil {
+ return nil, err
}
url = fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint)
@@ -832,19 +911,17 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner
return nil, err
}
- authHeader := make(map[string]string)
- var url string
- // Set Azure authentication - either Bearer token or api-key
- if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
- authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken)
- } else {
- authHeader["api-key"] = key.Value
+ // Get Azure authentication headers
+ authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
+ if err != nil {
+ return nil, err
}
+
apiVersion := key.AzureKeyConfig.APIVersion
if apiVersion == nil {
apiVersion = schemas.Ptr(AzureAPIVersionDefault)
}
- url = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion)
+ url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion)
// Create HTTP request for streaming
req := fasthttp.AcquireRequest()
@@ -1228,7 +1305,9 @@ func (provider *AzureProvider) FileUpload(ctx context.Context, key schemas.Key,
req.Header.SetContentType(writer.FormDataContentType())
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if err := provider.setAzureAuth(ctx, req, key); err != nil {
+ return nil, err
+ }
req.SetBody(buf.Bytes())
@@ -1333,7 +1412,9 @@ func (provider *AzureProvider) FileList(ctx context.Context, keys []schemas.Key,
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if err := provider.setAzureAuth(ctx, req, key); err != nil {
+ return nil, err
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
@@ -1437,7 +1518,12 @@ func (provider *AzureProvider) FileRetrieve(ctx context.Context, keys []schemas.
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil {
+ fasthttp.ReleaseRequest(req)
+ fasthttp.ReleaseResponse(resp)
+ lastErr = authErr
+ continue
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
@@ -1528,7 +1614,12 @@ func (provider *AzureProvider) FileDelete(ctx context.Context, keys []schemas.Ke
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil {
+ fasthttp.ReleaseRequest(req)
+ fasthttp.ReleaseResponse(resp)
+ lastErr = authErr
+ continue
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
@@ -1650,7 +1741,12 @@ func (provider *AzureProvider) FileContent(ctx context.Context, keys []schemas.K
req.Header.SetMethod(http.MethodGet)
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil {
+ fasthttp.ReleaseRequest(req)
+ fasthttp.ReleaseResponse(resp)
+ lastErr = authErr
+ continue
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
@@ -1766,7 +1862,9 @@ func (provider *AzureProvider) BatchCreate(ctx context.Context, key schemas.Key,
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if err := provider.setAzureAuth(ctx, req, key); err != nil {
+ return nil, err
+ }
// Build request body
openAIReq := &openai.OpenAIBatchRequest{
@@ -1885,7 +1983,9 @@ func (provider *AzureProvider) BatchList(ctx context.Context, keys []schemas.Key
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if err := provider.setAzureAuth(ctx, req, key); err != nil {
+ return nil, err
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
@@ -1984,7 +2084,12 @@ func (provider *AzureProvider) BatchRetrieve(ctx context.Context, keys []schemas
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil {
+ fasthttp.ReleaseRequest(req)
+ fasthttp.ReleaseResponse(resp)
+ lastErr = authErr
+ continue
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
@@ -2077,7 +2182,12 @@ func (provider *AzureProvider) BatchCancel(ctx context.Context, keys []schemas.K
req.Header.SetContentType("application/json")
// Set Azure authentication
- provider.setAzureAuth(ctx, req, key)
+ if authErr := provider.setAzureAuth(ctx, req, key); authErr != nil {
+ fasthttp.ReleaseRequest(req)
+ fasthttp.ReleaseResponse(resp)
+ lastErr = authErr
+ continue
+ }
// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
diff --git a/core/providers/azure/files.go b/core/providers/azure/files.go
index 88c03e218..d059f8d1b 100644
--- a/core/providers/azure/files.go
+++ b/core/providers/azure/files.go
@@ -5,20 +5,54 @@ import (
"fmt"
"time"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/maximhq/bifrost/core/providers/openai"
+ providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
-// setAzureAuth sets the Azure authentication header on the request.
-func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.Request, key schemas.Key) {
+// setAzureAuth sets the Azure authentication header on the request for OpenAI models.
+// It handles three authentication methods in order of priority:
+// 1. Service Principal (client ID/secret/tenant ID) - uses Bearer token
+// 2. Context token - uses Bearer token
+// 3. API key - uses api-key header
+func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.Request, key schemas.Key) *schemas.BifrostError {
+ // Service Principal authentication
+ if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil &&
+ key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil {
+ cred, err := provider.getOrCreateAuth(*key.AzureKeyConfig.TenantID, *key.AzureKeyConfig.ClientID, *key.AzureKeyConfig.ClientSecret)
+ if err != nil {
+ return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure)
+ }
+
+ token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
+ Scopes: []string{DefaultAzureScope},
+ })
+ if err != nil {
+ return providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure)
+ }
+
+ if token.Token == "" {
+ return providerUtils.NewBifrostOperationError("Azure access token is empty", fmt.Errorf("token is empty"), schemas.Azure)
+ }
+
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token))
+ req.Header.Del("api-key")
+ return nil
+ }
+
+ // Context token authentication
if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok && authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
req.Header.Del("api-key")
- } else {
- req.Header.Del("Authorization")
- req.Header.Set("api-key", key.Value)
+ return nil
}
+
+ // API key authentication
+ req.Header.Del("Authorization")
+ req.Header.Set("api-key", key.Value)
+ return nil
}
// AzureFileResponse represents an Azure file response (same as OpenAI).
diff --git a/core/schemas/account.go b/core/schemas/account.go
index 808a7586b..788efa277 100644
--- a/core/schemas/account.go
+++ b/core/schemas/account.go
@@ -24,6 +24,10 @@ type AzureKeyConfig struct {
Endpoint string `json:"endpoint"` // Azure service endpoint URL
Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names
APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21"
+
+ ClientID *string `json:"client_id,omitempty"` // Azure client ID for authentication
+ ClientSecret *string `json:"client_secret,omitempty"` // Azure client secret for authentication
+ TenantID *string `json:"tenant_id,omitempty"` // Azure tenant ID for authentication
}
// VertexKeyConfig represents the Vertex-specific configuration.
diff --git a/docs/quickstart/go-sdk/provider-configuration.mdx b/docs/quickstart/go-sdk/provider-configuration.mdx
index b3a281126..8945c8008 100644
--- a/docs/quickstart/go-sdk/provider-configuration.mdx
+++ b/docs/quickstart/go-sdk/provider-configuration.mdx
@@ -346,7 +346,42 @@ Enterprise cloud providers require additional configuration beyond API keys. Con
-Azure requires endpoint URLs, deployment mappings, and API version configuration:
+Azure supports two authentication methods:
+
+**Azure Entra ID (Service Principal)**
+
+Production-ready authentication with automatic token management. When `ClientID`, `ClientSecret`, and `TenantID` are set, Bifrost handles token acquisition using the default scope `https://cognitiveservices.azure.com/.default`.
+
+```go
+func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
+ switch provider {
+ case schemas.Azure:
+ return []schemas.Key{
+ {
+ Value: "", // Leave empty for Service Principal auth
+ Models: []string{"gpt-4o", "gpt-4o-mini"},
+ Weight: 1.0,
+ AzureKeyConfig: &schemas.AzureKeyConfig{
+ Endpoint: os.Getenv("AZURE_ENDPOINT"),
+ ClientID: bifrost.Ptr(os.Getenv("AZURE_CLIENT_ID")),
+ ClientSecret: bifrost.Ptr(os.Getenv("AZURE_CLIENT_SECRET")),
+ TenantID: bifrost.Ptr(os.Getenv("AZURE_TENANT_ID")),
+ Deployments: map[string]string{
+ "gpt-4o": "gpt-4o-deployment",
+ "gpt-4o-mini": "gpt-4o-mini-deployment",
+ },
+ APIVersion: bifrost.Ptr("2024-08-01-preview"),
+ },
+ },
+ }, nil
+ }
+ return nil, fmt.Errorf("provider %s not supported", provider)
+}
+```
+
+**Direct Authentication**
+
+For simpler use cases, provide the authentication credential directly in the `Value` field:
```go
func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
@@ -354,16 +389,16 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo
case schemas.Azure:
return []schemas.Key{
{
- Value: os.Getenv("AZURE_API_KEY"),
+ Value: os.Getenv("AZURE_OPENAI_KEY"),
Models: []string{"gpt-4o", "gpt-4o-mini"},
Weight: 1.0,
AzureKeyConfig: &schemas.AzureKeyConfig{
- Endpoint: os.Getenv("AZURE_ENDPOINT"), // e.g., "https://your-resource.openai.azure.com"
+ Endpoint: os.Getenv("AZURE_ENDPOINT"),
Deployments: map[string]string{
"gpt-4o": "gpt-4o-deployment",
"gpt-4o-mini": "gpt-4o-mini-deployment",
},
- APIVersion: bifrost.Ptr("2024-08-01-preview"), // Azure API version
+ APIVersion: bifrost.Ptr("2024-08-01-preview"),
},
},
}, nil
@@ -372,6 +407,10 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo
}
```
+
+If `ClientID`, `ClientSecret`, and `TenantID` are configured, Service Principal authentication is used. Otherwise, direct authentication with the `Value` field is used.
+
+
diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go
index 93975ff16..8fe31397c 100644
--- a/framework/configstore/migrations.go
+++ b/framework/configstore/migrations.go
@@ -110,6 +110,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error {
if err := migrationAddUseForBatchAPIColumnAndS3BucketsConfig(ctx, db); err != nil {
return err
}
+ if err := migrationAddAzureClientIDAndClientSecretAndTenantIDColumns(ctx, db); err != nil {
+ return err
+ }
return nil
}
@@ -1827,3 +1830,45 @@ func migrationAddUseForBatchAPIColumnAndS3BucketsConfig(ctx context.Context, db
}
return nil
}
+
+// migrationAddAzureClientIDAndClientSecretAndTenantIDColumns adds the azure_client_id, azure_client_secret, and azure_tenant_id columns to the key table
+func migrationAddAzureClientIDAndClientSecretAndTenantIDColumns(ctx context.Context, db *gorm.DB) error {
+ m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{
+ ID: "add_azure_client_id_and_client_secret_and_tenant_id_columns",
+ Migrate: func(tx *gorm.DB) error {
+ tx = tx.WithContext(ctx)
+ migrator := tx.Migrator()
+ if !migrator.HasColumn(&tables.TableKey{}, "azure_client_id") {
+ if err := migrator.AddColumn(&tables.TableKey{}, "azure_client_id"); err != nil {
+ return fmt.Errorf("failed to add azure_client_id column: %w", err)
+ }
+ }
+ if !migrator.HasColumn(&tables.TableKey{}, "azure_client_secret") {
+ if err := migrator.AddColumn(&tables.TableKey{}, "azure_client_secret"); err != nil {
+ return fmt.Errorf("failed to add azure_client_secret column: %w", err)
+ }
+ }
+ if !migrator.HasColumn(&tables.TableKey{}, "azure_tenant_id") {
+ if err := migrator.AddColumn(&tables.TableKey{}, "azure_tenant_id"); err != nil {
+ return fmt.Errorf("failed to add azure_tenant_id column: %w", err)
+ }
+ }
+ return nil
+ },
+ Rollback: func(tx *gorm.DB) error {
+ tx = tx.WithContext(ctx)
+ migrator := tx.Migrator()
+ if err := migrator.DropColumn(&tables.TableKey{}, "azure_client_id"); err != nil {
+ return fmt.Errorf("failed to drop azure_client_id column: %w", err)
+ }
+ if err := migrator.DropColumn(&tables.TableKey{}, "azure_client_secret"); err != nil {
+ return fmt.Errorf("failed to drop azure_client_secret column: %w", err)
+ }
+ if err := migrator.DropColumn(&tables.TableKey{}, "azure_tenant_id"); err != nil {
+ return fmt.Errorf("failed to drop azure_tenant_id column: %w", err)
+ }
+ return nil
+ },
+ }})
+ return m.Migrate()
+}
diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go
index ede86a175..9ac3643cc 100644
--- a/framework/configstore/tables/key.go
+++ b/framework/configstore/tables/key.go
@@ -30,6 +30,9 @@ type TableKey struct {
AzureEndpoint *string `gorm:"type:text" json:"azure_endpoint,omitempty"`
AzureAPIVersion *string `gorm:"type:varchar(50)" json:"azure_api_version,omitempty"`
AzureDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
+ AzureClientID *string `gorm:"type:varchar(255)" json:"azure_client_id,omitempty"`
+ AzureClientSecret *string `gorm:"type:text" json:"azure_client_secret,omitempty"`
+ AzureTenantID *string `gorm:"type:varchar(255)" json:"azure_tenant_id,omitempty"`
// Vertex config fields (embedded)
VertexProjectID *string `gorm:"type:varchar(255)" json:"vertex_project_id,omitempty"`
@@ -39,13 +42,13 @@ type TableKey struct {
VertexDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
// Bedrock config fields (embedded)
- BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"`
- BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"`
- BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"`
- BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"`
- BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"`
- BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
- BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config
+ BedrockAccessKey *string `gorm:"type:varchar(255)" json:"bedrock_access_key,omitempty"`
+ BedrockSecretKey *string `gorm:"type:text" json:"bedrock_secret_key,omitempty"`
+ BedrockSessionToken *string `gorm:"type:text" json:"bedrock_session_token,omitempty"`
+ BedrockRegion *string `gorm:"type:varchar(100)" json:"bedrock_region,omitempty"`
+ BedrockARN *string `gorm:"type:text" json:"bedrock_arn,omitempty"`
+ BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
+ BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config
// Batch API configuration
UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations
@@ -84,6 +87,9 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error {
k.AzureEndpoint = nil
}
k.AzureAPIVersion = k.AzureKeyConfig.APIVersion
+ k.AzureClientID = k.AzureKeyConfig.ClientID
+ k.AzureClientSecret = k.AzureKeyConfig.ClientSecret
+ k.AzureTenantID = k.AzureKeyConfig.TenantID
if k.AzureKeyConfig.Deployments != nil {
data, err := json.Marshal(k.AzureKeyConfig.Deployments)
if err != nil {
@@ -98,6 +104,9 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error {
k.AzureEndpoint = nil
k.AzureAPIVersion = nil
k.AzureDeploymentsJSON = nil
+ k.AzureClientID = nil
+ k.AzureClientSecret = nil
+ k.AzureTenantID = nil
}
if k.VertexKeyConfig != nil {
@@ -202,8 +211,11 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error {
// Reconstruct Azure config if fields are present
if k.AzureEndpoint != nil {
azureConfig := &schemas.AzureKeyConfig{
- Endpoint: "",
- APIVersion: k.AzureAPIVersion,
+ Endpoint: "",
+ APIVersion: k.AzureAPIVersion,
+ ClientID: k.AzureClientID,
+ ClientSecret: k.AzureClientSecret,
+ TenantID: k.AzureTenantID,
}
if k.AzureEndpoint != nil {
diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go
index ddb7b1fe8..e3d4d438b 100644
--- a/framework/configstore/tables/virtualkey.go
+++ b/framework/configstore/tables/virtualkey.go
@@ -53,6 +53,9 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error {
// Clear all Azure-related sensitive fields
key.AzureEndpoint = nil
key.AzureAPIVersion = nil
+ key.AzureClientID = nil
+ key.AzureClientSecret = nil
+ key.AzureTenantID = nil
key.AzureDeploymentsJSON = nil
key.AzureKeyConfig = nil
diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go
index 0ebcfe957..63811098d 100644
--- a/transports/bifrost-http/handlers/providers.go
+++ b/transports/bifrost-http/handlers/providers.go
@@ -764,6 +764,31 @@ func (h *ProviderHandler) mergeKeys(provider schemas.ModelProvider, oldRawKeys [
mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion
}
}
+ // handle client id and secret and tenant id
+ if updateKey.AzureKeyConfig.ClientID != nil &&
+ oldRedactedKey.AzureKeyConfig.ClientID != nil &&
+ oldRawKey.AzureKeyConfig != nil {
+ if lib.IsRedacted(*updateKey.AzureKeyConfig.ClientID) &&
+ strings.EqualFold(*updateKey.AzureKeyConfig.ClientID, *oldRedactedKey.AzureKeyConfig.ClientID) {
+ mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID
+ }
+ }
+ if updateKey.AzureKeyConfig.ClientSecret != nil &&
+ oldRedactedKey.AzureKeyConfig.ClientSecret != nil &&
+ oldRawKey.AzureKeyConfig != nil {
+ if lib.IsRedacted(*updateKey.AzureKeyConfig.ClientSecret) &&
+ strings.EqualFold(*updateKey.AzureKeyConfig.ClientSecret, *oldRedactedKey.AzureKeyConfig.ClientSecret) {
+ mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret
+ }
+ }
+ if updateKey.AzureKeyConfig.TenantID != nil &&
+ oldRedactedKey.AzureKeyConfig.TenantID != nil &&
+ oldRawKey.AzureKeyConfig != nil {
+ if lib.IsRedacted(*updateKey.AzureKeyConfig.TenantID) &&
+ strings.EqualFold(*updateKey.AzureKeyConfig.TenantID, *oldRedactedKey.AzureKeyConfig.TenantID) {
+ mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID
+ }
+ }
}
// Handle Vertex config redacted values
diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go
index b5a0fff8b..d9f786a5c 100644
--- a/transports/bifrost-http/lib/config.go
+++ b/transports/bifrost-http/lib/config.go
@@ -2166,7 +2166,34 @@ func (c *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*con
azureConfig.APIVersion = key.AzureKeyConfig.APIVersion
}
}
+ // Redact ClientID if present
+ if key.AzureKeyConfig.ClientID != nil {
+ path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_id", provider, key.ID)
+ if envVar, ok := envVarsByPath[path]; ok {
+ azureConfig.ClientID = bifrost.Ptr("env." + envVar)
+ } else {
+ azureConfig.ClientID = bifrost.Ptr(RedactKey(*key.AzureKeyConfig.ClientID))
+ }
+ }
+ // Redact ClientSecret if present
+ if key.AzureKeyConfig.ClientSecret != nil {
+ path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_secret", provider, key.ID)
+ if envVar, ok := envVarsByPath[path]; ok {
+ azureConfig.ClientSecret = bifrost.Ptr("env." + envVar)
+ } else if !strings.HasPrefix(*key.AzureKeyConfig.ClientSecret, "env.") {
+ azureConfig.ClientSecret = bifrost.Ptr(RedactKey(*key.AzureKeyConfig.ClientSecret))
+ }
+ }
+ // Redact TenantID if present
+ if key.AzureKeyConfig.TenantID != nil {
+ path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.tenant_id", provider, key.ID)
+ if envVar, ok := envVarsByPath[path]; ok {
+ azureConfig.TenantID = bifrost.Ptr("env." + envVar)
+ } else {
+ azureConfig.TenantID = bifrost.Ptr(RedactKey(*key.AzureKeyConfig.TenantID))
+ }
+ }
redactedConfig.Keys[i].AzureKeyConfig = azureConfig
}
@@ -3264,6 +3291,18 @@ func (c *Config) getFieldValue(key schemas.Key, fieldName string) string {
if key.AzureKeyConfig != nil && key.AzureKeyConfig.APIVersion != nil {
return *key.AzureKeyConfig.APIVersion
}
+ case "client_id":
+ if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil {
+ return *key.AzureKeyConfig.ClientID
+ }
+ case "client_secret":
+ if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientSecret != nil {
+ return *key.AzureKeyConfig.ClientSecret
+ }
+ case "tenant_id":
+ if key.AzureKeyConfig != nil && key.AzureKeyConfig.TenantID != nil {
+ return *key.AzureKeyConfig.TenantID
+ }
case "access_key":
if key.BedrockKeyConfig != nil {
return key.BedrockKeyConfig.AccessKey
@@ -3397,6 +3436,63 @@ func (c *Config) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas
azureConfig.APIVersion = &processedAPIVersion
}
+ // Process ClientID if present
+ if azureConfig.ClientID != nil {
+ processedClientID, envVar, err := c.processEnvValue(*azureConfig.ClientID)
+ if err != nil {
+ return err
+ }
+ if envVar != "" {
+ newEnvKeys[envVar] = struct{}{}
+ c.EnvKeys[envVar] = append(c.EnvKeys[envVar], configstore.EnvKeyInfo{
+ EnvVar: envVar,
+ Provider: provider,
+ KeyType: "azure_config",
+ ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_id", provider, key.ID),
+ KeyID: key.ID,
+ })
+ }
+ azureConfig.ClientID = &processedClientID
+ }
+
+ // Process ClientSecret if present
+ if azureConfig.ClientSecret != nil {
+ processedClientSecret, envVar, err := c.processEnvValue(*azureConfig.ClientSecret)
+ if err != nil {
+ return err
+ }
+ azureConfig.ClientSecret = &processedClientSecret
+ if envVar != "" {
+ newEnvKeys[envVar] = struct{}{}
+ c.EnvKeys[envVar] = append(c.EnvKeys[envVar], configstore.EnvKeyInfo{
+ EnvVar: envVar,
+ Provider: provider,
+ KeyType: "azure_config",
+ ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.client_secret", provider, key.ID),
+ KeyID: key.ID,
+ })
+ }
+ azureConfig.ClientSecret = &processedClientSecret
+ }
+
+ // Process TenantID if present
+ if azureConfig.TenantID != nil {
+ processedTenantID, envVar, err := c.processEnvValue(*azureConfig.TenantID)
+ if err != nil {
+ return err
+ }
+ if envVar != "" {
+ newEnvKeys[envVar] = struct{}{}
+ c.EnvKeys[envVar] = append(c.EnvKeys[envVar], configstore.EnvKeyInfo{
+ EnvVar: envVar,
+ Provider: provider,
+ KeyType: "azure_config",
+ ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.tenant_id", provider, key.ID),
+ KeyID: key.ID,
+ })
+ }
+ azureConfig.TenantID = &processedTenantID
+ }
return nil
}
diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx
index db5253bce..1dcf03f4d 100644
--- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx
+++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx
@@ -206,6 +206,54 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) {
)}
/>
+
+
+
+
+ Azure Entra ID Authentication
+
+ To use Azure Entra ID authentication, fill in Client ID, Client Secret, and Tenant ID. Please leave API Key empty when using Entra ID authentication.
+
+
+ (
+
+ Client ID (Optional)
+
+
+
+
+
+ )}
+ />
+ (
+
+ Client Secret (Optional)
+
+
+
+
+
+ )}
+ />
+ (
+
+ Tenant ID (Optional)
+
+
+
+
+
+ )}
+ />
!value || isValidDeployments(value), { message: "Valid Deployments (JSON object) are required for Azure keys" }),
api_version: z.string().optional(),
+ client_id: z.string().optional(),
+ client_secret: z.string().optional(),
+ tenant_id: z.string().optional(),
});
const VertexKeyConfigSchema = z.object({
diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts
index 9d616464a..25053cf32 100644
--- a/ui/lib/types/config.ts
+++ b/ui/lib/types/config.ts
@@ -24,12 +24,18 @@ export interface AzureKeyConfig {
endpoint: string;
deployments?: Record | string; // Allow string during editing
api_version?: string;
+ client_id?: string;
+ client_secret?: string;
+ tenant_id?: string;
}
export const DefaultAzureKeyConfig: AzureKeyConfig = {
endpoint: "",
deployments: {},
api_version: "2024-02-01",
+ client_id: "",
+ client_secret: "",
+ tenant_id: "",
} as const satisfies Required;
// VertexKeyConfig matching Go's schemas.VertexKeyConfig
diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts
index 60e9644fe..d44d59e47 100644
--- a/ui/lib/types/schemas.ts
+++ b/ui/lib/types/schemas.ts
@@ -18,6 +18,9 @@ export const azureKeyConfigSchema = z
endpoint: z.url("Must be a valid URL"),
deployments: z.union([z.record(z.string(), z.string()), z.string()]).optional(),
api_version: z.string().optional(),
+ client_id: z.string().optional(),
+ client_secret: z.string().optional(),
+ tenant_id: z.string().optional(),
})
.refine(
(data) => {