@@ -88,18 +88,29 @@ func (o *OIDCConnect) isValidState(ctx context.Context, req *Request, url *url.U
88
88
// Do we have stateid stored in querystring
89
89
var state * store.OIDCState
90
90
91
- stateToken := url .Query ().Get (stateQueryParamName )
92
-
93
- stateByte , err := o .Cache .Get (stateToken )
94
- if err == nil {
95
- state = store .ConvertToType (stateByte )
96
- }
91
+ // Check if there's a bearer token in the Authorization header
92
+ authHeader := req .Request .Header .Get ("Authorization" )
93
+ if strings .HasPrefix (authHeader , "Bearer " ) {
94
+ // Extract the token
95
+ token := strings .TrimPrefix (authHeader , "Bearer " )
96
+ // Create new state with the token
97
+ state = store .NewState ()
98
+ state .Status = store .StatusTokenReady
99
+ state .IDToken = token
100
+ } else {
101
+ stateToken := url .Query ().Get (stateQueryParamName )
97
102
98
- // State not found, try to retrieve from cookies.
99
- if state == nil {
100
- state , _ = o .getStateFromCookie (req )
103
+ stateByte , err := o .Cache .Get (stateToken )
104
+ if err == nil {
105
+ state = store .ConvertToType (stateByte )
106
+ }
107
+ // State not found, try to retrieve from cookies.
108
+ if state == nil {
109
+ state , _ = o .getStateFromCookie (req )
110
+ }
101
111
}
102
112
113
+ // State exists, proceed with token validation.
103
114
// State exists, proceed with token validation.
104
115
if state != nil {
105
116
// Re-initialize provider to refresh the context, this seems like bugs with coreos go-oidc module.
@@ -116,7 +127,7 @@ func (o *OIDCConnect) isValidState(ctx context.Context, req *Request, url *url.U
116
127
117
128
resp .Response .Header .Add (oauthTokenName , string (stateJSON ))
118
129
119
- if err := o .Cache .Delete (state .OAuthState ); err != nil {
130
+ if err := o .Cache .Delete (state .OAuthState ); err != nil && err != bigcache . ErrEntryNotFound {
120
131
o .Log .Error (err , "error deleting state" )
121
132
}
122
133
@@ -135,6 +146,13 @@ func (o *OIDCConnect) loginHandler(u *url.URL) Response {
135
146
state .RequestPath = path .Join (u .Host , u .Path )
136
147
state .Scheme = u .Scheme
137
148
149
+ config := o .oauth2Config ()
150
+
151
+ redirectURL := fmt .Sprintf ("%s://%s%s" , u .Scheme , u .Host , u .Path )
152
+ if redirectURL != config .RedirectURL && matchDomain (redirectURL , o .OidcConfig .AuthorizedRedirectDomains ) {
153
+ config .RedirectURL = redirectURL
154
+ }
155
+
138
156
authCodeURL := o .oauth2Config ().AuthCodeURL (state .OAuthState )
139
157
140
158
byteState := store .ConvertToByte (state )
@@ -205,6 +223,9 @@ func (o *OIDCConnect) callbackHandler(ctx context.Context, u *url.URL) (Response
205
223
resp .Response .Header .Add ("Location" ,
206
224
fmt .Sprintf ("%s://%s?%s=%s" , state .Scheme , state .RequestPath , stateQueryParamName , state .OAuthState ))
207
225
226
+ stateJSON , _ := json .Marshal (state )
227
+ resp .Response .Header .Add ("Set-Cookie" , fmt .Sprintf ("%s=%s; Path=/; Secure; SameSite=Lax" , oauthTokenName , string (stateJSON )))
228
+
208
229
return resp , nil
209
230
}
210
231
@@ -246,7 +267,6 @@ func (o *OIDCConnect) getStateFromCookie(req *Request) (*store.OIDCState, error)
246
267
// Check through and get the right cookies
247
268
if len (cookieVal ) > 0 {
248
269
cookies := strings .Split (cookieVal , ";" )
249
-
250
270
for _ , c := range cookies {
251
271
c = strings .TrimSpace (c )
252
272
if strings .HasPrefix (c , oauthTokenName ) {
@@ -280,7 +300,8 @@ func (o *OIDCConnect) oauth2Config() *oauth2.Config {
280
300
ClientSecret : o .OidcConfig .ClientSecret ,
281
301
Endpoint : o .provider .Endpoint (),
282
302
Scopes : o .OidcConfig .Scopes ,
283
- RedirectURL : o .OidcConfig .RedirectURL + o .OidcConfig .RedirectPath ,
303
+
304
+ RedirectURL : o .OidcConfig .RedirectURL + o .OidcConfig .RedirectPath ,
284
305
}
285
306
}
286
307
@@ -313,3 +334,60 @@ func parseURL(req *Request) *url.URL {
313
334
314
335
return u
315
336
}
337
+
338
+ // matchDomain checks if a domain matches any of the allowed patterns.
339
+ func matchDomain (domain string , allowedPatterns []string ) bool {
340
+ for _ , pattern := range allowedPatterns {
341
+ if matchPattern (domain , pattern ) {
342
+ return true
343
+ }
344
+ }
345
+ return false
346
+ }
347
+
348
+ // matchPattern checks if a domain matches a single pattern with wildcards.
349
+ func matchPattern (domain , pattern string ) bool {
350
+ // Split the pattern and domain into parts.
351
+ patternParts := strings .Split (pattern , "." )
352
+ domainParts := strings .Split (domain , "." )
353
+
354
+ // If the number of parts doesn't match, it's not a match.
355
+ if len (patternParts ) != len (domainParts ) {
356
+ return false
357
+ }
358
+
359
+ // Check each part of the pattern against the domain.
360
+ for i := range patternParts {
361
+ if ! matchPart (domainParts [i ], patternParts [i ]) {
362
+ return false
363
+ }
364
+ }
365
+
366
+ return true
367
+ }
368
+
369
+ // matchPart checks if a single part of the domain matches the pattern part.
370
+ func matchPart (domainPart , patternPart string ) bool {
371
+ // If the pattern part is a wildcard, it matches anything.
372
+ if patternPart == "*" {
373
+ return true
374
+ }
375
+
376
+ // Split the pattern part by the wildcard.
377
+ parts := strings .Split (patternPart , "*" )
378
+
379
+ // Check if the domain part matches the pattern parts in sequence.
380
+ pos := 0
381
+ for _ , part := range parts {
382
+ if part == "" {
383
+ continue
384
+ }
385
+ index := strings .Index (domainPart [pos :], part )
386
+ if index == - 1 {
387
+ return false
388
+ }
389
+ pos += index + len (part )
390
+ }
391
+
392
+ return true
393
+ }
0 commit comments