diff --git a/go.mod b/go.mod index e8ff713e..3c999865 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,10 @@ go 1.25.7 require ( github.com/99designs/keyring v1.2.2 github.com/AlecAivazis/survey/v2 v2.3.7 - github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2 v1.41.2 github.com/aws/aws-sdk-go-v2/config v1.27.11 github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/errors v0.9.1 @@ -62,12 +62,12 @@ require ( github.com/BurntSushi/toml v1.3.2 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect github.com/aws/aws-sdk-go-v2/service/iam v1.28.7 github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect - github.com/aws/smithy-go v1.20.2 + github.com/aws/smithy-go v1.24.1 github.com/common-fate/awsconfigfile v0.10.0 github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/danieljoos/wincred v1.1.2 // indirect diff --git a/go.sum b/go.sum index 7236b3d2..3751100a 100644 --- a/go.sum +++ b/go.sum @@ -16,18 +16,18 @@ github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63n github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= -github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= -github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= +github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= github.com/aws/aws-sdk-go-v2/service/iam v1.28.7 h1:FKPRDYZOO0Eur19vWUL1B40Op0j89KQj3kARjrszMK8= @@ -38,12 +38,12 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/g github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= -github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= -github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= +github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/common-fate/awsconfigfile v0.10.0 h1:9W0JTeO0d3jNLw3Ps9U7IJwLYp4D9zcipq/sqNEWJOg= @@ -63,8 +63,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dvsekhvalnov/jose2go v1.6.0 h1:Y9gnSnP4qEI0+/uQkHvFXeD2PLPJeXEL+ySMEA2EjTY= -github.com/dvsekhvalnov/jose2go v1.6.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/dvsekhvalnov/jose2go v1.8.0 h1:LqkkVKAlHFfH9LOEl5fe4p/zL02OhWE7pCufMBG2jLA= github.com/dvsekhvalnov/jose2go v1.8.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= diff --git a/pkg/cfaws/assumer_aws_sso.go b/pkg/cfaws/assumer_aws_sso.go index 35a065aa..d3fe645c 100644 --- a/pkg/cfaws/assumer_aws_sso.go +++ b/pkg/cfaws/assumer_aws_sso.go @@ -192,7 +192,19 @@ func (c *Profile) SSOLogin(ctx context.Context, configOpts ConfigOpts) (aws.Cred if cachedToken == nil && plainTextToken == nil { newCfg := aws.NewConfig() newCfg.Region = rootProfile.SSORegion() - newSSOToken, err := idclogin.Login(ctx, *newCfg, rootProfile.SSOStartURL(), rootProfile.SSOScopes()) + + var newSSOToken *securestorage.SSOToken + var err error + + // Use authorization code flow with PKCE when an sso_session is configured, + // unless the user has explicitly requested device code flow or we're + // in a headless environment where the localhost redirect won't work. + useDeviceCode := configOpts.UseDeviceCode || idclogin.IsHeadlessEnvironment() + if c.AWSConfig.SSOSessionName != "" && !useDeviceCode { + newSSOToken, err = idclogin.LoginWithAuthorizationCode(ctx, *newCfg, rootProfile.SSOStartURL(), rootProfile.SSOScopes()) + } else { + newSSOToken, err = idclogin.Login(ctx, *newCfg, rootProfile.SSOStartURL(), rootProfile.SSOScopes()) + } if err != nil { return aws.Credentials{}, err } diff --git a/pkg/cfaws/profiles.go b/pkg/cfaws/profiles.go index 140db9e9..baab7afb 100644 --- a/pkg/cfaws/profiles.go +++ b/pkg/cfaws/profiles.go @@ -27,6 +27,7 @@ type ConfigOpts struct { ShouldRetryAssuming *bool MFATokenCode string DisableCache bool + UseDeviceCode bool } type Profile struct { diff --git a/pkg/granted/sso.go b/pkg/granted/sso.go index 0af49e28..4c064b29 100644 --- a/pkg/granted/sso.go +++ b/pkg/granted/sso.go @@ -245,6 +245,7 @@ var LoginCommand = cli.Command{ &cli.StringFlag{Name: "sso-region", Usage: "Specify the SSO region"}, &cli.StringFlag{Name: "sso-start-url", Usage: "Specify the SSO start url"}, &cli.StringSliceFlag{Name: "sso-scope", Usage: "Specify the SSO scopes"}, + &cli.BoolFlag{Name: "use-device-code", Usage: "Use device code flow instead of authorization code with PKCE"}, }, Action: func(c *cli.Context) error { ctx := c.Context @@ -297,7 +298,15 @@ var LoginCommand = cli.Command{ secureSSOTokenStorage := securestorage.NewSecureSSOTokenStorage() - newSSOToken, err := idclogin.Login(ctx, *cfg, ssoStartUrl, ssoScopes) + var newSSOToken *securestorage.SSOToken + var err error + + useDeviceCode := c.Bool("use-device-code") || idclogin.IsHeadlessEnvironment() + if useDeviceCode { + newSSOToken, err = idclogin.Login(ctx, *cfg, ssoStartUrl, ssoScopes) + } else { + newSSOToken, err = idclogin.LoginWithAuthorizationCode(ctx, *cfg, ssoStartUrl, ssoScopes) + } if err != nil { return err } @@ -347,8 +356,12 @@ func (s AWSSSOSource) GetProfiles(ctx context.Context) ([]awsconfigfile.SSOProfi } if ssoTokenFromSecureCache == nil && ssoTokenFromPlainText == nil { - // otherwise, login with SSO - ssoTokenFromSecureCache, err = idclogin.Login(ctx, cfg, s.StartURL, s.SSOScopes) + // Login with SSO, using authorization code flow by default unless headless + if idclogin.IsHeadlessEnvironment() { + ssoTokenFromSecureCache, err = idclogin.Login(ctx, cfg, s.StartURL, s.SSOScopes) + } else { + ssoTokenFromSecureCache, err = idclogin.LoginWithAuthorizationCode(ctx, cfg, s.StartURL, s.SSOScopes) + } if err != nil { return nil, err } diff --git a/pkg/idclogin/authcode.go b/pkg/idclogin/authcode.go new file mode 100644 index 00000000..e5705a5d --- /dev/null +++ b/pkg/idclogin/authcode.go @@ -0,0 +1,304 @@ +package idclogin + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "html/template" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssooidc" + "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/securestorage" + "github.com/google/uuid" +) + +// authorizationCallbackTimeout is the maximum time to wait for the user to +// complete browser-based authentication before giving up. +const authorizationCallbackTimeout = 5 * time.Minute + +// LoginWithAuthorizationCode performs an Authorization Code Grant with PKCE flow +// to retrieve an SSO token. This provides a smoother UX than the device code flow +// by skipping the manual code entry step. +func LoginWithAuthorizationCode(ctx context.Context, cfg aws.Config, startUrl string, scopes []string) (*securestorage.SSOToken, error) { + if cfg.Region == "" { + return nil, errors.New("AWS region is required for authorization code flow") + } + + ssooidcClient := ssooidc.NewFromConfig(cfg) + + // The authorization code flow uses "sso:account:access" as the default scope, + // which is the modern scope for IAM Identity Center. This differs from the + // device code flow's legacy "sso-portal:*" default. + if len(scopes) == 0 { + scopes = []string{"sso:account:access"} + } + + // Bind the listener first to reserve a port for the redirect URI. + // We defer starting the HTTP server until after RegisterClient succeeds. + callbackResult := make(chan authCallbackResult, 1) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("failed to start local OAuth callback server: %w", err) + } + port := listener.Addr().(*net.TCPAddr).Port + redirectURI := fmt.Sprintf("http://127.0.0.1:%d/oauth/callback", port) + + state := uuid.New().String() + + // Register client with authorization_code grant type. + // + // The redirect URI for registration uses the portless form. Per RFC 8252 + // Section 7.3, authorization servers MUST allow any port to be specified for + // loopback redirect URIs. AWS IAM Identity Center implements this exemption: + // the portless URI is registered, but the actual redirect uses a port-specific URI. + client, err := ssooidcClient.RegisterClient(ctx, &ssooidc.RegisterClientInput{ + ClientName: aws.String("Granted CLI"), + ClientType: aws.String("public"), + GrantTypes: []string{"authorization_code", "refresh_token"}, + RedirectUris: []string{"http://127.0.0.1/oauth/callback"}, + IssuerUrl: aws.String(startUrl), + Scopes: scopes, + }) + if err != nil { + _ = listener.Close() + return nil, fmt.Errorf("failed to register OIDC client: %w", err) + } + + // Now that registration succeeded, start the HTTP server to receive the callback. + srv := &http.Server{ + Handler: newCallbackHandler(state, callbackResult), + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 30 * time.Second, + } + go func() { + if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + clio.Debugf("OAuth callback server error: %s", err) + } + }() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(shutdownCtx) + }() + + // Determine the authorization endpoint. The RegisterClient API may return it, + // but many regions don't include it in the response. Fall back to the standard + // regional endpoint pattern used by the AWS CLI. + authorizationEndpoint := fmt.Sprintf("https://oidc.%s.amazonaws.com/authorize", cfg.Region) + if client.AuthorizationEndpoint != nil && *client.AuthorizationEndpoint != "" { + authorizationEndpoint = *client.AuthorizationEndpoint + } + + // Generate PKCE code verifier and challenge + codeVerifier, err := generateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE code verifier: %w", err) + } + codeChallenge := computeCodeChallenge(codeVerifier) + + // Construct the authorization URL + authorizeURL, err := buildAuthorizeURL(authorizationEndpoint, *client.ClientId, redirectURI, state, codeChallenge, scopes) + if err != nil { + return nil, fmt.Errorf("failed to build authorize URL: %w", err) + } + + // Open browser with fallback message + if err := OpenBrowserWithFallbackMessage(authorizeURL); err != nil { + return nil, err + } + + clio.Info("Awaiting AWS authentication in the browser") + clio.Info("You will be prompted to authenticate and approve access") + + // Wait for the callback + var result authCallbackResult + select { + case result = <-callbackResult: + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(authorizationCallbackTimeout): + return nil, errors.New("timed out waiting for authorization callback") + } + + if result.err != nil { + return nil, fmt.Errorf("authorization failed: %w", result.err) + } + + // Exchange the authorization code for tokens + token, err := ssooidcClient.CreateToken(ctx, &ssooidc.CreateTokenInput{ + ClientId: client.ClientId, + ClientSecret: client.ClientSecret, + GrantType: aws.String("authorization_code"), + Code: aws.String(result.code), + CodeVerifier: aws.String(codeVerifier), + RedirectUri: aws.String(redirectURI), + }) + if err != nil { + return nil, fmt.Errorf("failed to exchange authorization code for token: %w", err) + } + + ssoToken := securestorage.SSOToken{ + AccessToken: *token.AccessToken, + Expiry: time.Now().Add(time.Duration(token.ExpiresIn) * time.Second), + ClientID: *client.ClientId, + ClientSecret: *client.ClientSecret, + RegistrationExpiresAt: time.Unix(client.ClientSecretExpiresAt, 0), + RefreshToken: token.RefreshToken, + Region: cfg.Region, + } + + return &ssoToken, nil +} + +type authCallbackResult struct { + code string + err error +} + +type callbackPageData struct { + Error string + Description string +} + +var callbackErrorTmpl = template.Must(template.New("error").Parse(` + +Granted - Authentication Failed + +
+

Authentication Failed

+

Error: {{.Error}}

+

{{.Description}}

+

Please close this window and try again.

+
+ +`)) + +const callbackSuccessHTML = ` + +Granted - Authentication Successful + +
+

Authentication Successful

+

You have successfully authenticated with AWS IAM Identity Center.

+

You can close this window and return to your terminal.

+
+ +` + +// setSecurityHeaders sets defensive HTTP headers on callback responses. +func setSecurityHeaders(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline'") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") +} + +func writeErrorPage(w http.ResponseWriter, errCode, description string) { + setSecurityHeaders(w) + w.WriteHeader(http.StatusBadRequest) + _ = callbackErrorTmpl.Execute(w, callbackPageData{ + Error: errCode, + Description: description, + }) +} + +func newCallbackHandler(expectedState string, result chan<- authCallbackResult) http.Handler { + var once sync.Once + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + // Only accept GET requests per OAuth 2.0 spec + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Only process the first callback request. Subsequent requests + // (browser retries, favicon fetches, attacker probes) are ignored. + var handled bool + once.Do(func() { + handled = true + query := r.URL.Query() + + if errParam := query.Get("error"); errParam != "" { + errDesc := query.Get("error_description") + writeErrorPage(w, errParam, errDesc) + result <- authCallbackResult{err: fmt.Errorf("%s: %s", errParam, errDesc)} + return + } + + code := query.Get("code") + state := query.Get("state") + + if state != expectedState { + writeErrorPage(w, "state_mismatch", "The state parameter did not match. This may indicate a CSRF attack.") + result <- authCallbackResult{err: errors.New("OAuth state parameter mismatch")} + return + } + + if code == "" { + writeErrorPage(w, "missing_code", "No authorization code was received.") + result <- authCallbackResult{err: errors.New("no authorization code received")} + return + } + + setSecurityHeaders(w) + w.WriteHeader(http.StatusOK) + // callbackSuccessHTML is a static string with no interpolation, safe to write directly + _, _ = w.Write([]byte(callbackSuccessHTML)) + result <- authCallbackResult{code: code} + }) + + if !handled { + http.Error(w, "Authorization already processed", http.StatusConflict) + } + }) + return mux +} + +// generateCodeVerifier generates a cryptographically random code verifier +// per RFC 7636. It produces a 43-128 character string from the unreserved character set. +func generateCodeVerifier() (string, error) { + // 32 bytes -> 43 base64url characters (no padding) + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// computeCodeChallenge computes the S256 code challenge from the verifier per RFC 7636. +func computeCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +func buildAuthorizeURL(authorizationEndpoint, clientID, redirectURI, state, codeChallenge string, scopes []string) (string, error) { + u, err := url.Parse(authorizationEndpoint) + if err != nil { + return "", err + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURI) + q.Set("state", state) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + q.Set("scope", strings.Join(scopes, " ")) + u.RawQuery = q.Encode() + + return u.String(), nil +} diff --git a/pkg/idclogin/authcode_test.go b/pkg/idclogin/authcode_test.go new file mode 100644 index 00000000..6d775b46 --- /dev/null +++ b/pkg/idclogin/authcode_test.go @@ -0,0 +1,194 @@ +package idclogin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateCodeVerifier(t *testing.T) { + verifier, err := generateCodeVerifier() + require.NoError(t, err) + + // RFC 7636: code verifier must be 43-128 characters + assert.GreaterOrEqual(t, len(verifier), 43) + assert.LessOrEqual(t, len(verifier), 128) + + // Should be base64url encoded (no padding) + assert.NotContains(t, verifier, "=") + assert.NotContains(t, verifier, "+") + assert.NotContains(t, verifier, "/") + + // Two verifiers should be different (randomness check) + verifier2, err := generateCodeVerifier() + require.NoError(t, err) + assert.NotEqual(t, verifier, verifier2) +} + +func TestComputeCodeChallenge(t *testing.T) { + // RFC 7636 Appendix B known test vector + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + challenge := computeCodeChallenge(verifier) + + // Expected value from RFC 7636 Appendix B + assert.Equal(t, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", challenge) + + // Challenge should not contain padding + assert.NotContains(t, challenge, "=") +} + +func TestBuildAuthorizeURL(t *testing.T) { + url, err := buildAuthorizeURL( + "https://oidc.us-west-2.amazonaws.com/authorize", + "client-123", + "http://127.0.0.1:12345/oauth/callback", + "state-uuid", + "challenge-value", + []string{"sso:account:access"}, + ) + require.NoError(t, err) + + assert.Contains(t, url, "response_type=code") + assert.Contains(t, url, "client_id=client-123") + assert.Contains(t, url, "redirect_uri=") + assert.Contains(t, url, "state=state-uuid") + assert.Contains(t, url, "code_challenge=challenge-value") + assert.Contains(t, url, "code_challenge_method=S256") + assert.Contains(t, url, "scope=sso%3Aaccount%3Aaccess") + // Ensure it's "scope" (singular per RFC 6749), not "scopes" + assert.NotContains(t, url, "scopes=") +} + +func TestCallbackHandler_Success(t *testing.T) { + result := make(chan authCallbackResult, 1) + handler := newCallbackHandler("expected-state", result) + + req := httptest.NewRequest("GET", "/oauth/callback?code=auth-code-123&state=expected-state", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "Authentication Successful") + assert.Equal(t, "no-store", w.Header().Get("Cache-Control")) + assert.Equal(t, "default-src 'none'; style-src 'unsafe-inline'", w.Header().Get("Content-Security-Policy")) + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + + r := <-result + assert.NoError(t, r.err) + assert.Equal(t, "auth-code-123", r.code) +} + +func TestCallbackHandler_StateMismatch(t *testing.T) { + result := make(chan authCallbackResult, 1) + handler := newCallbackHandler("expected-state", result) + + req := httptest.NewRequest("GET", "/oauth/callback?code=auth-code-123&state=wrong-state", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "Authentication Failed") + + r := <-result + assert.Error(t, r.err) + assert.Contains(t, r.err.Error(), "state parameter mismatch") +} + +func TestCallbackHandler_OAuthError(t *testing.T) { + result := make(chan authCallbackResult, 1) + handler := newCallbackHandler("expected-state", result) + + req := httptest.NewRequest("GET", "/oauth/callback?error=access_denied&error_description=User+denied+access", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "Authentication Failed") + + r := <-result + assert.Error(t, r.err) + assert.Contains(t, r.err.Error(), "access_denied") +} + +func TestCallbackHandler_MissingCode(t *testing.T) { + result := make(chan authCallbackResult, 1) + handler := newCallbackHandler("expected-state", result) + + req := httptest.NewRequest("GET", "/oauth/callback?state=expected-state", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + + r := <-result + assert.Error(t, r.err) + assert.Contains(t, r.err.Error(), "no authorization code received") +} + +func TestCallbackHandler_XSSPrevention(t *testing.T) { + result := make(chan authCallbackResult, 1) + handler := newCallbackHandler("expected-state", result) + + // Attempt XSS via error parameters + req := httptest.NewRequest("GET", "/oauth/callback?error=&error_description=", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + body := w.Body.String() + // html/template should escape the XSS payloads so they render as text, not HTML + assert.NotContains(t, body, "