diff --git a/pkg/ccl/logictestccl/testdata/logic_test/provisioning b/pkg/ccl/logictestccl/testdata/logic_test/provisioning index 508a3b90df5e..48122cbdf1c7 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/provisioning +++ b/pkg/ccl/logictestccl/testdata/logic_test/provisioning @@ -3,7 +3,7 @@ statement error role "root" cannot have a PROVISIONSRC ALTER ROLE root PROVISIONSRC 'ldap:ldap.example.com' -statement error pq: PROVISIONSRC "ldap.example.com" was not prefixed with any valid auth methods \["ldap" "jwt_token"\] +statement error pq: PROVISIONSRC "ldap.example.com" was not prefixed with any valid auth methods \["ldap" "jwt_token" "oidc"\] CREATE ROLE role_with_provisioning PROVISIONSRC 'ldap.example.com' statement error pq: conflicting role options diff --git a/pkg/ccl/oidcccl/BUILD.bazel b/pkg/ccl/oidcccl/BUILD.bazel index be2184896a28..a415060927b4 100644 --- a/pkg/ccl/oidcccl/BUILD.bazel +++ b/pkg/ccl/oidcccl/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//pkg/ccl/securityccl/jwthelper", "//pkg/ccl/utilccl", "//pkg/roachpb", + "//pkg/security/provisioning", "//pkg/security/username", "//pkg/server", "//pkg/server/authserver", @@ -55,6 +56,7 @@ go_test( "//pkg/ccl", "//pkg/roachpb", "//pkg/security/certnames", + "//pkg/security/provisioning", "//pkg/security/securityassets", "//pkg/security/securitytest", "//pkg/security/username", diff --git a/pkg/ccl/oidcccl/authentication_oidc.go b/pkg/ccl/oidcccl/authentication_oidc.go index 7e6616acc2ff..99d6810154bd 100644 --- a/pkg/ccl/oidcccl/authentication_oidc.go +++ b/pkg/ccl/oidcccl/authentication_oidc.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/jwtauthccl" "github.com/cockroachdb/cockroach/pkg/ccl/utilccl" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/security/provisioning" secuser "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/authserver" @@ -44,6 +45,7 @@ const ( codeKey = "code" stateKey = "state" secretCookieName = "oidc_secret" + oidcProvisioningKey = "oidc" oidcLoginPath = "/oidc/v1/login" oidcCallbackPath = "/oidc/v1/callback" oidcJWTPath = "/oidc/v1/jwt" @@ -158,6 +160,7 @@ type oidcAuthenticationConf struct { authZEnabled bool groupClaim string userinfoGroupKey string + provisioningEnabled bool } // GetOIDCConf is used to extract certain parts of the OIDC @@ -420,9 +423,10 @@ func reloadConfigLocked( httputil.WithDialerTimeout(clientTimeout), httputil.WithCustomCAPEM(OIDCProviderCustomCA.Get(&st.SV)), ), - authZEnabled: OIDCAuthZEnabled.Get(&st.SV), - groupClaim: OIDCAuthGroupClaim.Get(&st.SV), - userinfoGroupKey: OIDCAuthUserinfoGroupKey.Get(&st.SV), + authZEnabled: OIDCAuthZEnabled.Get(&st.SV), + groupClaim: OIDCAuthGroupClaim.Get(&st.SV), + userinfoGroupKey: OIDCAuthUserinfoGroupKey.Get(&st.SV), + provisioningEnabled: provisioning.ClusterProvisioningConfig(st).Enabled("oidc"), } if !oidcAuthServer.conf.enabled && conf.enabled { @@ -490,6 +494,69 @@ func getRegionSpecificRedirectURL(locality roachpb.Locality, conf redirectURLCon return s, nil } +// maybeProvisionUserLocked checks the cached OIDC configuration to see whether +// automatic user provisioning is enabled. If so, it attempts to create a SQL +// user linked to the OIDC identity provider. +// +// This function is called after a successful OIDC token exchange and +// verification. Its execution relies on the success of the underlying OIDC +// library, which operates on the following assumptions: +// +// 1. OIDC Discovery: The library uses the OIDC discovery protocol to fetch the +// provider's configuration from "/.well-known/openid-configuration". This +// assumes the provider has discovery enabled and accessible. +// +// 2. Issuer Validation: The go-oidc library's verifier ensures the 'iss' claim +// in the ID Token matches the issuer URL from the discovery document. There +// is also an exception made to this in go-oidc for accounts.google.com +// +// Errors during username validation, provisioning source parsing, or user creation +// are logged with detailed context and result in an HTTP 500 response with a +// generic error message to the client. +func maybeProvisionUserLocked( + ctx context.Context, + authConf oidcAuthenticationConf, + execCfg *sql.ExecutorConfig, + username string, + w http.ResponseWriter, +) (err error) { + if !authConf.provisioningEnabled { + return + } + + log.Dev.Infof(ctx, "OIDC: attempting user provisioning for %s", username) + telemetry.Inc(provisioning.BeginOIDCProvisionUseCounter) + + // Convert the extracted username string to a username.SQLUsername type. + sqlUsername, err := secuser.MakeSQLUsernameFromUserInput(username, secuser.PurposeCreation) + if err != nil { + log.Dev.Errorf(ctx, "OIDC provisioning: invalid username format for %s: %v", username, err) + http.Error(w, "OIDC: invalid username format", http.StatusInternalServerError) + return err + } + + // Create the provisioning source identifier string, e.g., "oidc:https://accounts.example.com". + idpString := oidcProvisioningKey + ":" + authConf.providerURL + provisioningSource, err := provisioning.ParseProvisioningSource(idpString) + if err != nil { + // This error occurs if the provisioning package doesn't recognize the "oidc:" prefix. + log.Dev.Errorf(ctx, "OIDC provisioning: error parsing provisioning source IDP %s: %v", idpString, err) + http.Error(w, "OIDC: provisioning error", http.StatusInternalServerError) + return err + } + + // Call the core provisioning function using the execCfg. + if err = sql.CreateRoleForProvisioning(ctx, execCfg, sqlUsername, provisioningSource.String()); err != nil { + log.Dev.Errorf(ctx, "OIDC provisioning: error provisioning user %s: %v", sqlUsername, err) + http.Error(w, "OIDC: provisioning error", http.StatusInternalServerError) + return err + } + + log.Dev.Infof(ctx, "OIDC: successfully provisioned user %s", sqlUsername) + telemetry.Inc(provisioning.ProvisionOIDCSuccessCounter) + return +} + // ConfigureOIDC attaches handlers to the server `mux` that // can initiate and complete an OIDC authentication flow. // This flow consists of an initial login request that triggers @@ -608,6 +675,12 @@ var ConfigureOIDC = func( return } + // OIDC user provisioning + if err := maybeProvisionUserLocked(ctx, oidcAuthentication.conf, oidcAuthentication.execCfg, username, w); err != nil { + log.Dev.Errorf(ctx, "OIDC provisioning failed with error: %v", err) + return + } + // OIDC authorization if err := oidcAuthentication.authorize(ctx, rawIDToken, rawAccessToken, username); err != nil { log.Dev.Errorf(ctx, "OIDC authorization failed with error: %v", err) @@ -938,6 +1011,9 @@ var ConfigureOIDC = func( OIDCAuthUserinfoGroupKey.SetOnChange(&st.SV, func(ctx context.Context) { reloadConfig(ambientCtx.AnnotateCtx(ctx), oidcAuthentication, locality, st) }) + provisioning.OIDCProvisioningEnabled.SetOnChange(&st.SV, func(ctx context.Context) { + reloadConfig(ambientCtx.AnnotateCtx(ctx), oidcAuthentication, locality, st) + }) return oidcAuthentication, nil } diff --git a/pkg/ccl/oidcccl/authentication_oidc_test.go b/pkg/ccl/oidcccl/authentication_oidc_test.go index 75f3b5cc300d..4e0b71175767 100644 --- a/pkg/ccl/oidcccl/authentication_oidc_test.go +++ b/pkg/ccl/oidcccl/authentication_oidc_test.go @@ -27,6 +27,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/certnames" + "github.com/cockroachdb/cockroach/pkg/security/provisioning" "github.com/cockroachdb/cockroach/pkg/security/securityassets" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" @@ -69,8 +70,10 @@ func TestOIDCBadRequestIfDisabled(t *testing.T) { } type mockOidcManager struct { - oauth2Config *oauth2.Config - claimEmail string + oauth2Config *oauth2.Config + claimEmail string + forceIssuerMismatch bool + forceExchangeFailure bool } func (m mockOidcManager) Verify(ctx context.Context, s string) (*oidc.IDToken, error) { @@ -90,9 +93,21 @@ func (m mockOidcManager) AuthCodeURL(s string, option ...oauth2.AuthCodeOption) func (m mockOidcManager) ExchangeVerifyGetTokenInfo( ctx context.Context, code, idTokenKey string, _ bool, ) (map[string]json.RawMessage, map[string]json.RawMessage, string, string, error) { + if m.forceIssuerMismatch { + return nil, nil, "", "", fmt.Errorf("oidc: token issuer mismatch") + } + if m.forceExchangeFailure { + return nil, nil, "", "", fmt.Errorf("token verification failed") + } + + emailClaimJSON, err := json.Marshal(m.claimEmail) + if err != nil { + return nil, nil, "", "", err + } claims := map[string]json.RawMessage{ - "email": json.RawMessage(`"test@example.com"`), + "email": emailClaimJSON, } + // Return nil for access token claims, and the raw token strings. return claims, nil, "dummy-id-token", "dummy-access-token", nil } @@ -725,3 +740,481 @@ func TestOIDCProviderCustomCACert(t *testing.T) { }) } } + +// TestOIDCProvisioning verifies the automatic user provisioning flow. +func TestOIDCProvisioning(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + // Start a full server to get access to the ExecutorConfig, which is + // necessary for the provisioning logic to execute. + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + Knobs: base.TestingKnobs{ + Server: &server.TestingKnobs{}, + }, + }) + defer s.Stopper().Stop(ctx) + ts := s.ApplicationLayer() + sqlDB := sqlutils.MakeSQLRunner(db) + + usernameUnderTest := "oidcprovisioneduser" + basePath := "/base/path/for/provisioning" + + // Mock the OIDC manager to simulate responses from an Identity Provider. + realNewManager := NewOIDCManager + NewOIDCManager = func(ctx context.Context, conf oidcAuthenticationConf, redirectURL string, scopes []string) (IOIDCManager, error) { + c := &oauth2.Config{ + ClientID: conf.clientID, + ClientSecret: conf.clientSecret, + RedirectURL: redirectURL, + Endpoint: oauth2.Endpoint{ + AuthURL: "https://provider.example.com/endpoint", + }, + Scopes: scopes, + } + // The mockOidcManager will extract `usernameUnderTest` from this email claim. + return &mockOidcManager{oauth2Config: c, claimEmail: fmt.Sprintf("%s@example.com", usernameUnderTest)}, nil + } + defer func() { + NewOIDCManager = realNewManager + }() + + // Configure the necessary OIDC cluster settings for the test. + OIDCProviderURL.Override(ctx, &ts.ClusterSettings().SV, "https://provider.example.com") + OIDCClientID.Override(ctx, &ts.ClusterSettings().SV, "fake_client_id_for_provisioning") + OIDCClientSecret.Override(ctx, &ts.ClusterSettings().SV, "fake_client_secret_for_provisioning") + OIDCRedirectURL.Override(ctx, &ts.ClusterSettings().SV, "https://cockroachlabs.com/oidc/v1/callback") + OIDCClaimJSONKey.Override(ctx, &ts.ClusterSettings().SV, "email") + OIDCPrincipalRegex.Override(ctx, &ts.ClusterSettings().SV, "^([^@]+)@[^@]+$") + server.ServerHTTPBasePath.Override(ctx, &ts.ClusterSettings().SV, basePath) + OIDCEnabled.Override(ctx, &ts.ClusterSettings().SV, true) + + // Setup an HTTP client to make requests to the server. + testCertsContext := ts.NewClientRPCContext(ctx, username.TestUserName()) + client, err := testCertsContext.GetHTTPClient() + require.NoError(t, err) + client.Timeout = 30 * time.Second + + // Sub-test for successful provisioning of a new user. + t.Run("provisioning enabled, new user", func(t *testing.T) { + // Ensure the user does not exist before the test and is dropped after. + sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + defer sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + + // Enable OIDC provisioning. + provisioning.OIDCProvisioningEnabled.Override(ctx, &ts.ClusterSettings().SV, true) + defer provisioning.OIDCProvisioningEnabled.Override(ctx, &ts.ClusterSettings().SV, false) + + // Hit the /login endpoint to get the state token and cookie. + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + resp, err := client.Get(ts.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusFound, resp.StatusCode) + + cookie := resp.Cookies()[0] + require.Equal(t, secretCookieName, cookie.Name) + + authURL, err := url.Parse(resp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + // Simulate the IdP redirect to the /callback endpoint. + client.CheckRedirect = nil // Allow redirects to follow through to the final page. + req, err := http.NewRequest("GET", ts.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + require.NoError(t, err) + req.AddCookie(cookie) + q := req.URL.Query() + q.Add("state", stateParam) + q.Add("code", "some-auth-code-for-provisioning") + req.URL.RawQuery = q.Encode() + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Assert a successful login, indicated by a 200 OK status after the redirect. + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, basePath, resp.Request.URL.Path) + + // Verify that the user was created in the database. + rows := sqlDB.Query(t, "SELECT username FROM [SHOW USERS] WHERE username = $1", usernameUnderTest) + require.True(t, rows.Next(), "user should have been created by provisioning") + require.NoError(t, rows.Close()) + }) + + // Sub-test for a successful login for an existing user while provisioning is enabled. + t.Run("provisioning enabled, existing user", func(t *testing.T) { + // Create the user beforehand. + sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + sqlDB.Exec(t, fmt.Sprintf("CREATE USER %s", usernameUnderTest)) + defer sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + + // Enable OIDC provisioning. + provisioning.OIDCProvisioningEnabled.Override(ctx, &ts.ClusterSettings().SV, true) + defer provisioning.OIDCProvisioningEnabled.Override(ctx, &ts.ClusterSettings().SV, false) + + // Simulate the login flow. + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } + resp, err := client.Get(ts.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusFound, resp.StatusCode) + cookie := resp.Cookies()[0] + authURL, err := url.Parse(resp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + client.CheckRedirect = nil + req, err := http.NewRequest("GET", ts.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + require.NoError(t, err) + req.AddCookie(cookie) + q := req.URL.Query() + q.Add("state", stateParam) + q.Add("code", "some-auth-code-for-existing-user") + req.URL.RawQuery = q.Encode() + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Assert a successful login. + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify the user still exists. + rows := sqlDB.Query(t, "SELECT username FROM [SHOW USERS] WHERE username = $1", usernameUnderTest) + require.True(t, rows.Next(), "user should still exist") + require.NoError(t, rows.Close()) + }) + + // Sub-test for a successful login for an existing user while provisioning is disabled. + t.Run("provisioning disabled, existing user", func(t *testing.T) { + // Create the user beforehand. + sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + sqlDB.Exec(t, fmt.Sprintf("CREATE USER %s", usernameUnderTest)) + defer sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + + // Disable OIDC provisioning. + provisioning.OIDCProvisioningEnabled.Override(ctx, &ts.ClusterSettings().SV, false) + + // Simulate the login flow. + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } + resp, err := client.Get(ts.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusFound, resp.StatusCode) + cookie := resp.Cookies()[0] + authURL, err := url.Parse(resp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + client.CheckRedirect = nil + req, err := http.NewRequest("GET", ts.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + require.NoError(t, err) + req.AddCookie(cookie) + q := req.URL.Query() + q.Add("state", stateParam) + q.Add("code", "some-auth-code-for-existing-user-provisioning-disabled") + req.URL.RawQuery = q.Encode() + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Assert a successful login. + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify the user still exists. + rows := sqlDB.Query(t, "SELECT username FROM [SHOW USERS] WHERE username = $1", usernameUnderTest) + require.True(t, rows.Next(), "user should still exist") + require.NoError(t, rows.Close()) + }) + + // Sub-test to ensure no user is created when provisioning is disabled. + t.Run("provisioning disabled, new user", func(t *testing.T) { + // Ensure the user does not exist. + sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + defer sqlDB.Exec(t, fmt.Sprintf("DROP USER IF EXISTS %s", usernameUnderTest)) + + // Disable OIDC provisioning. + provisioning.OIDCProvisioningEnabled.Override(ctx, &ts.ClusterSettings().SV, false) + + // Simulate the login flow. + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } + resp, err := client.Get(ts.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusFound, resp.StatusCode) + cookie := resp.Cookies()[0] + authURL, err := url.Parse(resp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + client.CheckRedirect = nil + req, err := http.NewRequest("GET", ts.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + require.NoError(t, err) + req.AddCookie(cookie) + q := req.URL.Query() + q.Add("state", stateParam) + q.Add("code", "some-auth-code-for-disabled-provisioning") + req.URL.RawQuery = q.Encode() + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Assert a failed login because the user does not exist and will not be created. + require.Equal(t, http.StatusForbidden, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), genericCallbackHTTPError) + + // Verify the user was NOT created. + rows := sqlDB.Query(t, "SELECT username FROM [SHOW USERS] WHERE username = $1", usernameUnderTest) + require.False(t, rows.Next(), "user should not have been created") + require.NoError(t, rows.Close()) + }) +} + +// TestOIDCExchangeVerifyFailure ensures that a verification error inside +// ExchangeVerifyGetTokenInfo surfaces as an HTTP 500 at /oidc/v1/callback. +func TestOIDCExchangeVerifyFailure(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + s := srv.ApplicationLayer() + + usernameUnderTest := "verifyfailuser" + basePath := "/some/oidc/path" + + // Intercept NewOIDCManager and return the failing mock. + realNewManager := NewOIDCManager + NewOIDCManager = func( + ctx context.Context, + conf oidcAuthenticationConf, + redirectURL string, + scopes []string, + ) (IOIDCManager, error) { + c := &oauth2.Config{ + ClientID: conf.clientID, + ClientSecret: conf.clientSecret, + RedirectURL: redirectURL, + Endpoint: oauth2.Endpoint{AuthURL: "https://provider.example.com/endpoint"}, + Scopes: scopes, + } + return &mockOidcManager{ + oauth2Config: c, + claimEmail: fmt.Sprintf("%s@example.com", usernameUnderTest), + forceExchangeFailure: true, + }, nil + } + defer func() { NewOIDCManager = realNewManager }() + + // Minimal cluster‑setting wiring to enable OIDC. + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, fmt.Sprintf(`CREATE USER %s WITH PASSWORD 'placeholder'`, usernameUnderTest)) + + OIDCProviderURL.Override(ctx, &s.ClusterSettings().SV, "providerURL") + OIDCClientID.Override(ctx, &s.ClusterSettings().SV, "fake_client_id") + OIDCClientSecret.Override(ctx, &s.ClusterSettings().SV, "fake_client_secret") + OIDCRedirectURL.Override(ctx, &s.ClusterSettings().SV, "https://cockroachlabs.com/oidc/v1/callback") + OIDCClaimJSONKey.Override(ctx, &s.ClusterSettings().SV, "email") + OIDCPrincipalRegex.Override(ctx, &s.ClusterSettings().SV, "^([^@]+)@[^@]+$") + server.ServerHTTPBasePath.Override(ctx, &s.ClusterSettings().SV, basePath) + OIDCEnabled.Override(ctx, &s.ClusterSettings().SV, true) + + testCtx := s.NewClientRPCContext(ctx, username.TestUserName()) + client, err := testCtx.GetHTTPClient() + require.NoError(t, err) + client.Timeout = 30 * time.Second + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } + + // hit /login to get state & cookie. + loginResp, err := client.Get(s.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer loginResp.Body.Close() + require.Equal(t, http.StatusFound, loginResp.StatusCode) + + cookie := loginResp.Cookies()[0] + authURL, err := url.Parse(loginResp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + // simulate IdP redirect to /callback with same state. + req, err := http.NewRequest("GET", s.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + require.NoError(t, err) + req.AddCookie(cookie) + q := req.URL.Query() + q.Add("state", stateParam) + q.Add("code", "irrelevant-auth-code") + req.URL.RawQuery = q.Encode() + + callbackResp, err := client.Do(req) + require.NoError(t, err) + defer callbackResp.Body.Close() + + // Verification failure in the manager should surface as 500 + generic error. + require.Equal(t, http.StatusInternalServerError, callbackResp.StatusCode) + body, err := io.ReadAll(callbackResp.Body) + require.NoError(t, err) + require.Contains(t, string(body), genericCallbackHTTPError) +} + +func TestOIDCIssuerValidation(t *testing.T) { + // Sub-test 1: This test ensures that an ID token with an untrusted, + // malicious, or misconfigured issuer (`iss` claim) is rejected. + t.Run("fails when token's issuer differs from the configured provider", func(t *testing.T) { + + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv := serverutils.StartServerOnly(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + s := srv.ApplicationLayer() + + // Replace NewOIDCManager with manager that errors on Verify. + restore := testutils.TestingHook( + &NewOIDCManager, + func(ctx context.Context, c oidcAuthenticationConf, redirectURL string, scopes []string) (IOIDCManager, error) { + conf := &oauth2.Config{ + ClientID: c.clientID, + RedirectURL: redirectURL, + Endpoint: oauth2.Endpoint{AuthURL: c.providerURL}, + Scopes: scopes, + } + return &mockOidcManager{oauth2Config: conf, forceIssuerMismatch: true}, nil + }, + ) + defer restore() + + // Minimal cluster settings to enable OIDC. + OIDCProviderURL.Override(ctx, &s.ClusterSettings().SV, "https://provider.example.com") + OIDCClientID.Override(ctx, &s.ClusterSettings().SV, "cid") + OIDCClientSecret.Override(ctx, &s.ClusterSettings().SV, "sec") + OIDCRedirectURL.Override(ctx, &s.ClusterSettings().SV, "https://cockroachlabs.com/oidc/v1/callback") + OIDCEnabled.Override(ctx, &s.ClusterSettings().SV, true) + + httpClient, err := s.NewClientRPCContext(ctx, username.TestUserName()).GetHTTPClient() + require.NoError(t, err) + httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse } + + // Start login (302 expected). + loginResp, err := httpClient.Get(s.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer loginResp.Body.Close() + require.Equal(t, http.StatusFound, loginResp.StatusCode) + + // Extract state from the login redirect to use in the callback. + authURL, err := url.Parse(loginResp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + // Simulate callback; expect 500 because the mock manager will fail verification. + cbReq, _ := http.NewRequest("GET", s.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + // Use the extracted stateParam + q := cbReq.URL.Query() + q.Add("state", stateParam) + q.Add("code", "bad-code") + cbReq.URL.RawQuery = q.Encode() + + for _, c := range loginResp.Cookies() { + cbReq.AddCookie(c) + } + cbResp, err := httpClient.Do(cbReq) + require.NoError(t, err) + defer cbResp.Body.Close() + require.Equal(t, http.StatusInternalServerError, cbResp.StatusCode) + + // Validate that the response body contains the generic error message. + body, err := io.ReadAll(cbResp.Body) + require.NoError(t, err) + require.Contains(t, string(body), genericCallbackHTTPError) + }) + + // Sub‑test 2: provider is not issuer; external issuer must be rejected. + t.Run("external issuer rejected when provider not issuer", func(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + s := srv.ApplicationLayer() + + const usernameUnderTest = "externissuer" + sqlutils.MakeSQLRunner(db).Exec(t, fmt.Sprintf(`CREATE USER "%s"`, usernameUnderTest)) + + // Inject manager that returns a token signed by an issuer != provider. + restore := testutils.TestingHook( + &NewOIDCManager, + func(ctx context.Context, c oidcAuthenticationConf, redirectURL string, scopes []string) (IOIDCManager, error) { + conf := &oauth2.Config{ + ClientID: c.clientID, + RedirectURL: redirectURL, + Endpoint: oauth2.Endpoint{AuthURL: c.providerURL}, + Scopes: scopes, + } + + return &mockOidcManager{ + oauth2Config: conf, + claimEmail: fmt.Sprintf("%s@example.com", usernameUnderTest), + forceIssuerMismatch: true, + }, nil + }, + ) + defer restore() + + // Enable OIDC with a provider that differs from the token's issuer. + OIDCProviderURL.Override(ctx, &s.ClusterSettings().SV, "https://proxy-provider.example.com") + OIDCClientID.Override(ctx, &s.ClusterSettings().SV, "cid") + OIDCClientSecret.Override(ctx, &s.ClusterSettings().SV, "sec") + OIDCRedirectURL.Override(ctx, &s.ClusterSettings().SV, "https://cockroachlabs.com/oidc/v1/callback") + OIDCClaimJSONKey.Override(ctx, &s.ClusterSettings().SV, "email") + OIDCPrincipalRegex.Override(ctx, &s.ClusterSettings().SV, "^([^@]+)@[^@]+$") + OIDCEnabled.Override(ctx, &s.ClusterSettings().SV, true) + + httpClient, err := s.NewClientRPCContext(ctx, username.TestUserName()).GetHTTPClient() + require.NoError(t, err) + httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse } + + // Begin auth flow (expect 302). + loginResp, err := httpClient.Get(s.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + defer loginResp.Body.Close() + require.Equal(t, http.StatusFound, loginResp.StatusCode) + + // Extract the real state parameter from the login redirect. + authURL, err := url.Parse(loginResp.Header.Get("Location")) + require.NoError(t, err) + stateParam := authURL.Query().Get("state") + + // Complete callback; CockroachDB must reject the mismatched issuer -> 500. + cbReq, _ := http.NewRequest("GET", s.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + + // Use the extracted stateParam + q := cbReq.URL.Query() + q.Add("state", stateParam) + q.Add("code", "good-code") + cbReq.URL.RawQuery = q.Encode() + + for _, c := range loginResp.Cookies() { + cbReq.AddCookie(c) + } + cbResp, err := httpClient.Do(cbReq) + require.NoError(t, err) + defer cbResp.Body.Close() + require.Equal(t, http.StatusInternalServerError, cbResp.StatusCode) + + // Validate that the response body contains the generic error message. + body, err := io.ReadAll(cbResp.Body) + require.NoError(t, err) + require.Contains(t, string(body), genericCallbackHTTPError) + }) +} diff --git a/pkg/security/provisioning/provisioning_source.go b/pkg/security/provisioning/provisioning_source.go index 55790f3ecb00..175a23dcbf40 100644 --- a/pkg/security/provisioning/provisioning_source.go +++ b/pkg/security/provisioning/provisioning_source.go @@ -51,7 +51,7 @@ func ParseProvisioningSource(sourceStr string) (*Source, error) { } func parseAuthMethod(sourceStr string) (authMethod string, idp string, err error) { - supportedProvisioningMethods := []string{supportedAuthMethodLDAP, supportedAuthMethodJWT} + supportedProvisioningMethods := []string{supportedAuthMethodLDAP, supportedAuthMethodJWT, supportedAuthMethodOIDC} for _, method := range supportedProvisioningMethods { prefix := method + ":" if strings.HasPrefix(sourceStr, prefix) { diff --git a/pkg/security/provisioning/provisioning_source_test.go b/pkg/security/provisioning/provisioning_source_test.go index 4df1ef3027c6..b34af83b7bf1 100644 --- a/pkg/security/provisioning/provisioning_source_test.go +++ b/pkg/security/provisioning/provisioning_source_test.go @@ -55,7 +55,7 @@ func TestParseProvisioningSource(t *testing.T) { }, } - for _, method := range []string{"ldap", "jwt_token"} { + for _, method := range []string{"ldap", "jwt_token", "oidc"} { t.Run(method, func(t *testing.T) { for _, tt := range sharedTests { t.Run(tt.name, func(t *testing.T) { @@ -86,6 +86,15 @@ func TestParseProvisioningSource(t *testing.T) { require.Equal(t, "https://accounts.google.com", source.idp.String()) }) + // Test case specific to oidc allowing https + t.Run("oidc/valid_https_source", func(t *testing.T) { + source, err := ParseProvisioningSource("oidc:https://accounts.google.com") + require.NoError(t, err) + require.NotNil(t, source) + require.Equal(t, "oidc", source.authMethod) + require.Equal(t, "https://accounts.google.com", source.idp.String()) + }) + // According to the current implementation of `parseIDP`, this is valid. // The function `url.Parse` accepts this and the checks for Port and Opaque pass. t.Run("ldap/https_source_is_valid_by_current_implementation", func(t *testing.T) { @@ -159,7 +168,7 @@ func TestValidateSource(t *testing.T) { }, } - for _, method := range []string{"ldap", "jwt_token"} { + for _, method := range []string{"ldap", "jwt_token", "oidc"} { t.Run(method, func(t *testing.T) { for _, tt := range sharedTests { t.Run(tt.name, func(t *testing.T) { @@ -188,6 +197,11 @@ func TestValidateSource(t *testing.T) { require.NoError(t, err) }) + t.Run("oidc/valid_https_source", func(t *testing.T) { + err := ValidateSource("oidc:https://accounts.google.com") + require.NoError(t, err) + }) + // Global failure cases t.Run("global_failures", func(t *testing.T) { globalTests := []struct { diff --git a/pkg/security/provisioning/settings.go b/pkg/security/provisioning/settings.go index 46f6186c5c4c..2eabb74f0516 100644 --- a/pkg/security/provisioning/settings.go +++ b/pkg/security/provisioning/settings.go @@ -15,16 +15,29 @@ import ( // All cluster settings necessary for the provisioning feature. const ( - supportedAuthMethodLDAP = "ldap" + supportedAuthMethodLDAP = "ldap" + // Although "oidc" is not a valid auth method in the AuthMethod + // factory, this is added so provisioning remains consistent with + // LDAP/JWT and can be referenced via the PROVISIONSRC role option. + // + // TODO(souravcrl): ensure proper setting of the role option for + // PROVISIONSRC incase both JWT and OIDC provisioning are + // enabled. While users can be provisioned post authentication from + // both auth methods, since the underlying user identity is the JWT + // token, we can end up with divergent user cohorts due to auth + // method set for the option value. + supportedAuthMethodOIDC = "oidc" testSupportedAuthMethodCertPassword = "cert-password" supportedAuthMethodJWT = "jwt_token" baseProvisioningSettingName = "security.provisioning." ldapProvisioningEnableSettingName = baseProvisioningSettingName + "ldap.enabled" jwtProvisioningEnableSettingName = baseProvisioningSettingName + "jwt.enabled" + oidcProvisioningEnableSettingName = baseProvisioningSettingName + "oidc.enabled" baseCounterPrefix = "auth.provisioning." ldapCounterPrefix = baseCounterPrefix + "ldap." jwtCounterPrefix = baseCounterPrefix + "jwt." + oidcCounterPrefix = baseCounterPrefix + "oidc." beginLDAPProvisionCounterName = ldapCounterPrefix + "begin" provisionLDAPSuccessCounterName = ldapCounterPrefix + "success" @@ -34,6 +47,10 @@ const ( provisionJWTSuccessCounterName = jwtCounterPrefix + "success" enableJWTProvisionCounterName = jwtCounterPrefix + "enable" + beginOIDCProvisionCounterName = oidcCounterPrefix + "begin" + provisionOIDCSuccessCounterName = oidcCounterPrefix + "success" + enableOIDCProvisionCounterName = oidcCounterPrefix + "enable" + provisionedUserLoginSuccessCounterName = baseCounterPrefix + "login_success" ) @@ -46,6 +63,10 @@ var ( ProvisionJWTSuccessCounter = telemetry.GetCounterOnce(provisionJWTSuccessCounterName) enableJWTProvisionCounter = telemetry.GetCounterOnce(enableJWTProvisionCounterName) + BeginOIDCProvisionUseCounter = telemetry.GetCounterOnce(beginOIDCProvisionCounterName) + ProvisionOIDCSuccessCounter = telemetry.GetCounterOnce(provisionOIDCSuccessCounterName) + enableOIDCProvisionCounter = telemetry.GetCounterOnce(enableOIDCProvisionCounterName) + ProvisionedUserLoginSuccessCounter = telemetry.GetCounterOnce(provisionedUserLoginSuccessCounterName) ) @@ -76,6 +97,15 @@ var jwtProvisioningEnabled = settings.RegisterBoolSetting( false, ) +// OIDCProvisioningEnabled enables automatic user provisioning for DB Console OIDC +// authentication method. +var OIDCProvisioningEnabled = settings.RegisterBoolSetting( + settings.ApplicationLevel, + oidcProvisioningEnableSettingName, + "enables or disables automatic user provisioning for oidc authentication method", + false, +) + type clusterProvisioningConfig struct { settings *cluster.Settings } @@ -91,6 +121,8 @@ func (c clusterProvisioningConfig) Enabled(authMethod string) bool { switch authMethod { case supportedAuthMethodLDAP: return ldapProvisioningEnabled.Get(&c.settings.SV) + case supportedAuthMethodOIDC: + return OIDCProvisioningEnabled.Get(&c.settings.SV) case testSupportedAuthMethodCertPassword: return Testing.Supported case supportedAuthMethodJWT: @@ -114,5 +146,10 @@ func ClusterProvisioningConfig(settings *cluster.Settings) UserProvisioningConfi telemetry.Inc(enableJWTProvisionCounter) } }) + OIDCProvisioningEnabled.SetOnChange(&settings.SV, func(_ context.Context) { + if OIDCProvisioningEnabled.Get(&settings.SV) { + telemetry.Inc(enableOIDCProvisionCounter) + } + }) return clusterProvisioningConfig{settings: settings} }