diff --git a/cmd/nginx-ingress/aws.go b/cmd/nginx-ingress/aws.go index 9bb937932f..a6ccc6e0ec 100644 --- a/cmd/nginx-ingress/aws.go +++ b/cmd/nginx-ingress/aws.go @@ -15,7 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/marketplacemetering" "github.com/aws/aws-sdk-go-v2/service/marketplacemetering/types" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" ) var ( @@ -24,6 +24,12 @@ var ( pubKeyString string ) +var ( + ErrMissingProductCode = errors.New("token doesn't include the ProductCode") + ErrMissingNonce = errors.New("token doesn't include the Nonce") + ErrMissingKeyVersion = errors.New("token doesn't include the PublicKeyVersion") +) + func init() { startupCheckFn = checkAWSEntitlement } @@ -95,21 +101,18 @@ type claims struct { jwt.RegisteredClaims } -func (c claims) Valid() error { +var _ jwt.ClaimsValidator = (*claims)(nil) + +func (c claims) Validate() error { if c.Nonce == "" { - return jwt.NewValidationError("token doesn't include the Nonce", jwt.ValidationErrorClaimsInvalid) + return ErrMissingNonce } if c.ProductCode == "" { - return jwt.NewValidationError("token doesn't include the ProductCode", jwt.ValidationErrorClaimsInvalid) + return ErrMissingProductCode } if c.PublicKeyVersion == 0 { - return jwt.NewValidationError("token doesn't include the PublicKeyVersion", jwt.ValidationErrorClaimsInvalid) + return ErrMissingKeyVersion } - - if err := c.RegisteredClaims.Valid(); err != nil { - return err - } - return nil } diff --git a/cmd/nginx-ingress/aws_test.go b/cmd/nginx-ingress/aws_test.go index 1c151f0b5d..65b7d41a58 100644 --- a/cmd/nginx-ingress/aws_test.go +++ b/cmd/nginx-ingress/aws_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" ) func TestValidClaims(t *testing.T) { @@ -21,69 +21,72 @@ func TestValidClaims(t *testing.T) { IssuedAt: &iat, }, } - if err := c.Valid(); err != nil { + v := jwt.NewValidator( + jwt.WithIssuedAt(), + ) + if err := v.Validate(c); err != nil { t.Fatalf("Failed to verify claims, wanted: %v got %v", nil, err) } } func TestInvalidClaims(t *testing.T) { - badClaims := []struct { - c claims - expectedError error + type fields struct { + leeway time.Duration + timeFunc func() time.Time + expectedAud string + expectAllAud []string + expectedIss string + expectedSub string + } + type args struct { + claims jwt.Claims + } + tests := []struct { + name string + fields fields + args args + wantErr error }{ { - claims{ - "", - 1, - "nonce", - jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * -1)), - }, - }, - errors.New("token doesn't include the ProductCode"), + name: "missing ProductCode", + fields: fields{}, + args: args{jwt.RegisteredClaims{}}, + wantErr: ErrMissingProductCode, }, { - claims{ - "productCode", - 1, - "", - jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * -1)), - }, - }, - errors.New("token doesn't include the Nonce"), + name: "missing Nonce", + fields: fields{}, + args: args{jwt.RegisteredClaims{}}, + wantErr: ErrMissingNonce, }, { - claims{ - "productCode", - 0, - "nonce", - jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * -1)), - }, - }, - errors.New("token doesn't include the PublicKeyVersion"), + name: "missing PublicKeyVersion", + fields: fields{}, + args: args{jwt.RegisteredClaims{}}, + wantErr: ErrMissingKeyVersion, }, { - claims{ - "test", - 1, - "nonce", - jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * +2)), - }, - }, - errors.New("token used before issued"), + name: "iat is in the future", + fields: fields{}, + args: args{jwt.RegisteredClaims{IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour * +2))}}, + wantErr: jwt.ErrTokenUsedBeforeIssued, }, } - for _, badC := range badClaims { - - err := badC.c.Valid() - if err == nil { - t.Errorf("Valid() returned no error when it should have returned error %q", badC.expectedError) - } else if err.Error() != badC.expectedError.Error() { - t.Errorf("Valid() returned error %q when it should have returned error %q", err, badC.expectedError) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := jwt.NewValidator( + jwt.WithLeeway(tt.fields.leeway), + jwt.WithTimeFunc(tt.fields.timeFunc), + jwt.WithIssuedAt(), + jwt.WithAudience(tt.fields.expectedAud), + jwt.WithAllAudiences(tt.fields.expectAllAud...), + jwt.WithIssuer(tt.fields.expectedIss), + jwt.WithSubject(tt.fields.expectedSub), + ) + if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) { + t.Errorf("validator.Validate() error = %v, wantErr = %v", err, tt.wantErr) + } + }) } } diff --git a/go.mod b/go.mod index cf7f5e8585..ebcf30bd9e 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/cert-manager/cert-manager v1.18.2 github.com/dlclark/regexp2 v1.11.5 github.com/gkampitakis/go-snaps v0.5.15 - github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/go-cmp v0.7.0 github.com/gruntwork-io/terratest v0.50.0 github.com/jinzhu/copier v0.4.0 diff --git a/go.sum b/go.sum index bed431ca64..5509008052 100644 --- a/go.sum +++ b/go.sum @@ -169,6 +169,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/gonvenience/bunt v1.3.5 h1:wSQquifvwEWtzn27k1ngLfeLaStyt0k1b/K6TrlCNAs=