Skip to content

Commit

Permalink
forward name to oauth2 context and provide an accessor
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Porto Carrero <[email protected]>
  • Loading branch information
casualjim committed Aug 17, 2019
1 parent 7a84b65 commit c6fb0f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
18 changes: 17 additions & 1 deletion security/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type secCtxKey uint8

const (
failedBasicAuth secCtxKey = iota
oauth2SchemeName
)

func FailedBasicAuth(r *http.Request) string {
Expand All @@ -89,6 +90,18 @@ func FailedBasicAuthCtx(ctx context.Context) string {
return v
}

func OAuth2SchemeName(r *http.Request) string {
return OAuth2SchemeNameCtx(r.Context())
}

func OAuth2SchemeNameCtx(ctx context.Context) string {
v, ok := ctx.Value(oauth2SchemeName).(string)
if !ok {
return ""
}
return v
}

// BasicAuth creates a basic auth authenticator with the provided authentication function
func BasicAuth(authenticate UserPassAuthentication) runtime.Authenticator {
return BasicAuthRealm(DefaultRealmName, authenticate)
Expand Down Expand Up @@ -224,6 +237,8 @@ func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Aut
return false, nil, nil
}

rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
*r.Request = *r.Request.WithContext(rctx)
p, err := authenticate(token, r.RequiredScopes)
return true, p, err
})
Expand Down Expand Up @@ -252,7 +267,8 @@ func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runti
return false, nil, nil
}

ctx, p, err := authenticate(r.Request.Context(), token, r.RequiredScopes)
rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
ctx, p, err := authenticate(rctx, token, r.RequiredScopes)
*r.Request = *r.Request.WithContext(ctx)
return true, p, err
})
Expand Down
8 changes: 8 additions & 0 deletions security/bearer_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req1), "owners_auth")

req2, _ := http.NewRequest("GET", "/blah", nil)
req2.Header.Set("Authorization", "Bearer token123")
Expand All @@ -37,6 +38,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req2), "owners_auth")

body := url.Values(map[string][]string{})
body.Set("access_token", "token123")
Expand All @@ -47,6 +49,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req3), "owners_auth")

mpbody := bytes.NewBuffer(nil)
writer := multipart.NewWriter(mpbody)
Expand All @@ -59,6 +62,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req4), "owners_auth")
}

func TestInvalidBearerAuth(t *testing.T) {
Expand Down Expand Up @@ -162,6 +166,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req1.Context().Value(original))
assert.Equal(t, extraWisdom, req1.Context().Value(extra))
assert.Nil(t, req1.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req1), "owners_auth")

req2, _ := http.NewRequest("GET", "/blah", nil)
req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom))
Expand All @@ -174,6 +179,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req2.Context().Value(original))
assert.Equal(t, extraWisdom, req2.Context().Value(extra))
assert.Nil(t, req2.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req2), "owners_auth")

body := url.Values(map[string][]string{})
body.Set("access_token", "token123")
Expand All @@ -188,6 +194,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req3.Context().Value(original))
assert.Equal(t, extraWisdom, req3.Context().Value(extra))
assert.Nil(t, req3.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req3), "owners_auth")

mpbody := bytes.NewBuffer(nil)
writer := multipart.NewWriter(mpbody)
Expand All @@ -204,6 +211,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req4.Context().Value(original))
assert.Equal(t, extraWisdom, req4.Context().Value(extra))
assert.Nil(t, req4.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req4), "owners_auth")
}

func TestInvalidBearerAuthCtx(t *testing.T) {
Expand Down

0 comments on commit c6fb0f1

Please sign in to comment.