Skip to content

Commit 500a688

Browse files
authored
feat: Validate token ownership (#388)
1 parent a7de988 commit 500a688

File tree

3 files changed

+123
-4
lines changed

3 files changed

+123
-4
lines changed

diode-server/auth/mocks/tokenownershipprovider.go

Lines changed: 51 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

diode-server/auth/server.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,15 @@ func statusFromError(err error) int {
8989
return http.StatusInternalServerError
9090
}
9191

92+
// TokenOwnershipValidationData contains data for validating token ownership
93+
type TokenOwnershipValidationData struct {
94+
Headers http.Header
95+
}
96+
9297
// TokenOwnershipProvider determines the owner of a token
9398
type TokenOwnershipProvider interface {
9499
TokenOwnerID(ctx context.Context, token string) (string, error)
100+
ValidateTokenOwnership(data TokenOwnershipValidationData, claims jwt.MapClaims) error
95101
}
96102

97103
// DefaultTokenOwner is a default implementation of TokenOwnershipProvider
@@ -102,6 +108,11 @@ func (p *DefaultTokenOwner) TokenOwnerID(_ context.Context, _ string) (string, e
102108
return DefaultTokenOwnerID, nil
103109
}
104110

111+
// ValidateTokenOwnership validates the ownership of a token
112+
func (p *DefaultTokenOwner) ValidateTokenOwnership(_ TokenOwnershipValidationData, _ jwt.MapClaims) error {
113+
return nil
114+
}
115+
105116
// ClientInfoDecorator attaches additional information to a client info
106117
type ClientInfoDecorator interface {
107118
VisitClientInfo(ctx context.Context, clientInfo *ClientInfo) error
@@ -209,6 +220,13 @@ func (s *Server) introspect(w http.ResponseWriter, r *http.Request) {
209220
return
210221
}
211222

223+
err = s.tokenOwnership.ValidateTokenOwnership(TokenOwnershipValidationData{Headers: r.Header}, claims)
224+
if err != nil {
225+
s.logger.Error("failed to validate token ownership", "error", err)
226+
w.WriteHeader(http.StatusForbidden)
227+
return
228+
}
229+
212230
resp := IntrospectResponse{
213231
Active: true,
214232
Subject: getStringClaim(claims, "sub"),

diode-server/auth/server_test.go

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"log/slog"
910
"net"
@@ -42,14 +43,25 @@ func (p MockTokenParser) Parse(token string, _ jwt.Keyfunc) (*jwt.Token, error)
4243
return nil, fmt.Errorf("token not found")
4344
}
4445

46+
type ownerInvalid struct{}
47+
48+
func (o ownerInvalid) TokenOwnerID(_ context.Context, _ string) (string, error) {
49+
return auth.DefaultTokenOwnerID, nil
50+
}
51+
52+
func (o ownerInvalid) ValidateTokenOwnership(_ auth.TokenOwnershipValidationData, _ jwt.MapClaims) error {
53+
return errors.New("invalid token owner")
54+
}
55+
4556
func TestNewServer(t *testing.T) {
4657
ctx := context.Background()
4758

4859
setupEnv()
4960
defer teardownEnv()
5061

5162
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: false}))
52-
server, err := auth.NewServer(ctx, logger, InvalidParser{}, nil, nil)
63+
defaultOwnership := &auth.DefaultTokenOwner{}
64+
server, err := auth.NewServer(ctx, logger, InvalidParser{}, nil, defaultOwnership)
5365
require.NoError(t, err)
5466
require.NotNil(t, server)
5567

@@ -96,7 +108,8 @@ func TestIntrospectForInvalidTokens(t *testing.T) {
96108
}()
97109

98110
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: false}))
99-
server, err := auth.NewServer(ctx, logger, InvalidParser{}, nil, nil)
111+
defaultOwnership := &auth.DefaultTokenOwner{}
112+
server, err := auth.NewServer(ctx, logger, InvalidParser{}, nil, defaultOwnership)
100113
require.NoError(t, err)
101114
require.NotNil(t, server)
102115

@@ -138,6 +151,7 @@ func TestIntrospectForValidTokens(t *testing.T) {
138151
name string
139152
token string
140153
tokenParser auth.TokenParser
154+
invalidOwner bool
141155
expectedStatus int
142156
expectedAudience []string
143157
expectedSubject string
@@ -201,6 +215,35 @@ func TestIntrospectForValidTokens(t *testing.T) {
201215
expectedClientID: "client123",
202216
expectedUsername: "testuser",
203217
},
218+
{
219+
name: "Valid Token with invalid owner",
220+
token: testToken,
221+
tokenParser: &MockTokenParser{
222+
tokenMap: map[string]jwt.Token{
223+
testToken: {
224+
Claims: jwt.MapClaims{
225+
"iss": "https://auth.example.com",
226+
"sub": "user123",
227+
"aud": "api",
228+
"exp": time.Now().Add(time.Hour).Unix(),
229+
"iat": time.Now().Unix(),
230+
"client_id": "client123",
231+
"scope": "read write",
232+
"username": "testuser",
233+
},
234+
Valid: true,
235+
},
236+
},
237+
},
238+
invalidOwner: true,
239+
expectedStatus: http.StatusForbidden,
240+
expectedAudience: []string{"api"},
241+
expectedSubject: "user123",
242+
expectedScope: "read write",
243+
expectedIssuer: "https://auth.example.com",
244+
expectedClientID: "client123",
245+
expectedUsername: "testuser",
246+
},
204247
}
205248

206249
// Setup a test server to mock the OAuth2 server
@@ -238,7 +281,11 @@ func TestIntrospectForValidTokens(t *testing.T) {
238281
t.Run(test.name, func(t *testing.T) {
239282
ctx := context.Background()
240283

241-
server, err := auth.NewServer(ctx, logger, test.tokenParser, nil, nil)
284+
var ownerProvider auth.TokenOwnershipProvider = &auth.DefaultTokenOwner{}
285+
if test.invalidOwner {
286+
ownerProvider = ownerInvalid{}
287+
}
288+
server, err := auth.NewServer(ctx, logger, test.tokenParser, nil, ownerProvider)
242289
require.NoError(t, err)
243290
require.NotNil(t, server)
244291

@@ -252,7 +299,10 @@ func TestIntrospectForValidTokens(t *testing.T) {
252299
resp, err := makeIntrospectRequest(testServer.URL, test.token)
253300

254301
require.NoError(t, err)
255-
require.Equal(t, http.StatusOK, resp.StatusCode)
302+
require.Equal(t, test.expectedStatus, resp.StatusCode)
303+
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
304+
return
305+
}
256306

257307
defer func() {
258308
_ = resp.Body.Close()

0 commit comments

Comments
 (0)