Skip to content

Commit 60c2e7f

Browse files
tanyav2claude
andcommitted
Revert "feat(shim): forward verified JWT subject to upstream"
This reverts commit b469dab. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 73f3f15 commit 60c2e7f

7 files changed

Lines changed: 46 additions & 123 deletions

File tree

tinfoil/cmd/shim/api.go

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,6 @@ const (
8888
errTypeServer = "server_error"
8989
)
9090

91-
// subjectHeader carries the authenticated principal (JWT `sub`) from the shim to
92-
// the upstream workload after a JWT access token is verified in-enclave. The
93-
// upstream trusts it because the shim is the only path to it and unconditionally
94-
// strips any client-supplied value before setting its own. Must match the header
95-
// the upstream (confidential-model-router) reads.
96-
const subjectHeader = "X-Tinfoil-Subject"
97-
9891
// Client-facing error messages, aligned with OpenAI's standard error messages
9992
// where applicable. See https://platform.openai.com/docs/guides/error-codes
10093
const (
@@ -221,20 +214,7 @@ func NewShimServer(
221214
}
222215

223216
proxyHandler := ehbpMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
224-
// Never trust a client-supplied identity header: the shim is the sole
225-
// authority for subjectHeader and (re)sets it only after verifying a
226-
// JWT below. Strip it unconditionally, including on unauthenticated
227-
// paths, so a forged value can never reach the upstream.
228-
r.Header.Del(subjectHeader)
229-
230217
apiKey := extractBearerToken(r.Header.Get("Authorization"))
231-
232-
// rateLimitID identifies the caller for the shim's local rate limiter.
233-
// It defaults to the bearer credential and is replaced with the verified
234-
// JWT subject when available, so a user's bucket is stable across the
235-
// short-lived token's refreshes (and across multiple tokens).
236-
rateLimitID := apiKey
237-
238218
if validator != nil && requiresAuth(config.AuthenticatedEndpoints, r.URL.Path) {
239219
if len(apiKey) == 0 {
240220
writeJSONError(w, errMsgAPIKeyRequired, errTypeInvalidRequest, http.StatusUnauthorized)
@@ -248,28 +228,19 @@ func NewShimServer(
248228
Path: r.URL.Path,
249229
}
250230

251-
res, err := validator.Validate(validationReq)
252-
if err != nil {
231+
if err := validator.Validate(validationReq); err != nil {
253232
log.Printf("Warning: failed to validate API key: %v", err)
254233
writeValidationFailure(w, err)
255234
return
256235
}
257-
258-
// Forward the verified principal to the upstream so it can attribute
259-
// usage and rate limits to a stable identity rather than the
260-
// rotating bearer token. Empty for opaque keys (online validator).
261-
if res.Subject != "" {
262-
r.Header.Set(subjectHeader, res.Subject)
263-
rateLimitID = res.Subject
264-
}
265236
}
266237

267238
if rateLimiter != nil {
268239
if apiKey == "" {
269240
writeJSONError(w, errMsgAPIKeyRequired, errTypeInvalidRequest, http.StatusUnauthorized)
270241
return
271242
}
272-
limiter := rateLimiter.Limit(rateLimitID)
243+
limiter := rateLimiter.Limit(apiKey)
273244
if !limiter.Allow() {
274245
writeJSONError(w, errMsgRateLimited, errTypeInvalidRequest, http.StatusTooManyRequests)
275246
return

tinfoil/internal/key/jwt/jwt.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -171,57 +171,55 @@ func NewValidator(jwksURL, issuer, audience, requiredScope string) *Validator {
171171
}
172172
}
173173

174-
func (v *Validator) Validate(req key.Request) (key.Result, error) {
174+
func (v *Validator) Validate(req key.Request) error {
175175
if !isAccessTokenJWT(req.APIKey) {
176-
return key.Result{}, key.ErrUnsupportedToken
176+
return key.ErrUnsupportedToken
177177
}
178178

179179
token, err := josejwt.ParseSigned(req.APIKey, []jose.SignatureAlgorithm{jose.EdDSA})
180180
if err != nil || len(token.Headers) == 0 {
181-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusUnauthorized}
181+
return &key.ValidationError{StatusCode: http.StatusUnauthorized}
182182
}
183183

184184
// RFC 9068 registers the access-token type as "at+jwt"; RFC 7515 also
185185
// permits the equivalent media type carrying an "application/" prefix, so
186186
// accept both forms case-insensitively.
187187
typ, _ := token.Headers[0].ExtraHeaders[jose.HeaderType].(string)
188188
if normalizeType(typ) != accessTokenType {
189-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusUnauthorized}
189+
return &key.ValidationError{StatusCode: http.StatusUnauthorized}
190190
}
191191

192192
signingKey, ok := v.keys.lookup(token.Headers[0].KeyID)
193193
if !ok {
194194
v.keys.refreshIfAllowed()
195195
signingKey, ok = v.keys.lookup(token.Headers[0].KeyID)
196196
if !ok {
197-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusUnauthorized}
197+
return &key.ValidationError{StatusCode: http.StatusUnauthorized}
198198
}
199199
}
200200

201201
var claims josejwt.Claims
202202
var ext accessTokenClaims
203203
if err := token.Claims(signingKey, &claims, &ext); err != nil {
204-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusUnauthorized}
204+
return &key.ValidationError{StatusCode: http.StatusUnauthorized}
205205
}
206206
if claims.Subject == "" || claims.Expiry == nil || claims.IssuedAt == nil || claims.ID == "" || ext.ClientID == "" {
207-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusUnauthorized}
207+
return &key.ValidationError{StatusCode: http.StatusUnauthorized}
208208
}
209209

210210
if err := claims.Validate(josejwt.Expected{
211211
Issuer: v.issuer,
212212
AnyAudience: josejwt.Audience{v.audience},
213213
Time: time.Now(),
214214
}); err != nil {
215-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusUnauthorized}
215+
return &key.ValidationError{StatusCode: http.StatusUnauthorized}
216216
}
217217

218218
if !scopeContains(ext.Scope, v.scope) {
219-
return key.Result{}, &key.ValidationError{StatusCode: http.StatusForbidden}
219+
return &key.ValidationError{StatusCode: http.StatusForbidden}
220220
}
221221

222-
// The verified subject lets the shim forward a stable identity to the
223-
// upstream so usage/rate limits attach to the user, not the rotating token.
224-
return key.Result{Subject: claims.Subject}, nil
222+
return nil
225223
}
226224

227225
// isAccessTokenJWT reports whether s is explicitly typed as an access-token

tinfoil/internal/key/jwt/jwt_test.go

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -91,28 +91,16 @@ func chatRequest(token string) key.Request {
9191
return key.Request{APIKey: token, Path: "/v1/chat/completions"}
9292
}
9393

94-
// validateErr runs the validator and returns only the error, for the many
95-
// rejection cases that do not care about the Result. The accepted-token tests
96-
// call Validate directly to assert the surfaced subject.
97-
func validateErr(val *Validator, req key.Request) error {
98-
_, err := val.Validate(req)
99-
return err
100-
}
101-
10294
func TestValidateAcceptsValidToken(t *testing.T) {
10395
pub, priv, _ := ed25519.GenerateKey(nil)
10496
srv := jwksServer(t, pub, testKID)
10597
defer srv.Close()
10698
v := newTestValidator(t, srv.URL)
10799

108100
token := mintToken(t, priv, testKID, "at+jwt", validClaims(time.Now()), RequiredScope)
109-
res, err := v.Validate(chatRequest(token))
110-
if err != nil {
101+
if err := v.Validate(chatRequest(token)); err != nil {
111102
t.Fatalf("expected valid token, got %v", err)
112103
}
113-
if res.Subject != "user_1" {
114-
t.Fatalf("Subject = %q, want user_1", res.Subject)
115-
}
116104
}
117105

118106
func TestValidateFallsThroughForOpaqueKey(t *testing.T) {
@@ -121,7 +109,7 @@ func TestValidateFallsThroughForOpaqueKey(t *testing.T) {
121109
defer srv.Close()
122110
v := newTestValidator(t, srv.URL)
123111

124-
err := validateErr(v, key.Request{APIKey: "chat_abcdef"})
112+
err := v.Validate(key.Request{APIKey: "chat_abcdef"})
125113
if !errors.Is(err, key.ErrUnsupportedToken) {
126114
t.Fatalf("expected ErrUnsupportedToken, got %v", err)
127115
}
@@ -133,7 +121,7 @@ func TestValidateFallsThroughForDottedOpaqueKey(t *testing.T) {
133121
defer srv.Close()
134122
v := newTestValidator(t, srv.URL)
135123

136-
err := validateErr(v, key.Request{APIKey: "opaque.with.dots"})
124+
err := v.Validate(key.Request{APIKey: "opaque.with.dots"})
137125
if !errors.Is(err, key.ErrUnsupportedToken) {
138126
t.Fatalf("expected ErrUnsupportedToken, got %v", err)
139127
}
@@ -148,7 +136,7 @@ func TestValidateRejectsWrongAudience(t *testing.T) {
148136
claims := validClaims(time.Now())
149137
claims.Audience = josejwt.Audience{"https://example.com"}
150138
token := mintToken(t, priv, testKID, "at+jwt", claims, RequiredScope)
151-
expectStatus(t, validateErr(v, chatRequest(token)), http.StatusUnauthorized)
139+
expectStatus(t, v.Validate(chatRequest(token)), http.StatusUnauthorized)
152140
}
153141

154142
func TestValidateRejectsExpired(t *testing.T) {
@@ -158,7 +146,7 @@ func TestValidateRejectsExpired(t *testing.T) {
158146
v := newTestValidator(t, srv.URL)
159147

160148
token := mintToken(t, priv, testKID, "at+jwt", validClaims(time.Now().Add(-time.Hour)), RequiredScope)
161-
expectStatus(t, validateErr(v, chatRequest(token)), http.StatusUnauthorized)
149+
expectStatus(t, v.Validate(chatRequest(token)), http.StatusUnauthorized)
162150
}
163151

164152
func TestValidateRejectsMissingExpiration(t *testing.T) {
@@ -170,7 +158,7 @@ func TestValidateRejectsMissingExpiration(t *testing.T) {
170158
claims := validClaims(time.Now())
171159
claims.Expiry = nil
172160
token := mintToken(t, priv, testKID, "at+jwt", claims, RequiredScope)
173-
expectStatus(t, validateErr(v, chatRequest(token)), http.StatusUnauthorized)
161+
expectStatus(t, v.Validate(chatRequest(token)), http.StatusUnauthorized)
174162
}
175163

176164
func TestValidateRejectsMissingScope(t *testing.T) {
@@ -180,7 +168,7 @@ func TestValidateRejectsMissingScope(t *testing.T) {
180168
v := newTestValidator(t, srv.URL)
181169

182170
token := mintToken(t, priv, testKID, "at+jwt", validClaims(time.Now()), "models:read")
183-
expectStatus(t, validateErr(v, chatRequest(token)), http.StatusForbidden)
171+
expectStatus(t, v.Validate(chatRequest(token)), http.StatusForbidden)
184172
}
185173

186174
func TestValidateRejectsWrongIssuer(t *testing.T) {
@@ -192,7 +180,7 @@ func TestValidateRejectsWrongIssuer(t *testing.T) {
192180
claims := validClaims(time.Now())
193181
claims.Issuer = "https://evil.example.com"
194182
token := mintToken(t, priv, testKID, "at+jwt", claims, RequiredScope)
195-
expectStatus(t, validateErr(v, chatRequest(token)), http.StatusUnauthorized)
183+
expectStatus(t, v.Validate(chatRequest(token)), http.StatusUnauthorized)
196184
}
197185

198186
func TestValidateFallsThroughForWrongType(t *testing.T) {
@@ -202,7 +190,7 @@ func TestValidateFallsThroughForWrongType(t *testing.T) {
202190
v := newTestValidator(t, srv.URL)
203191

204192
token := mintToken(t, priv, testKID, "JWT", validClaims(time.Now()), RequiredScope)
205-
err := validateErr(v, chatRequest(token))
193+
err := v.Validate(chatRequest(token))
206194
if !errors.Is(err, key.ErrUnsupportedToken) {
207195
t.Fatalf("expected ErrUnsupportedToken, got %v", err)
208196
}
@@ -218,7 +206,7 @@ func TestValidateRejectsForeignSignature(t *testing.T) {
218206
// verification against the published key must fail.
219207
_, foreignPriv, _ := ed25519.GenerateKey(nil)
220208
token := mintToken(t, foreignPriv, testKID, "at+jwt", validClaims(time.Now()), RequiredScope)
221-
expectStatus(t, validateErr(v, chatRequest(token)), http.StatusUnauthorized)
209+
expectStatus(t, v.Validate(chatRequest(token)), http.StatusUnauthorized)
222210
}
223211

224212
func TestValidateAcceptsApplicationPrefixType(t *testing.T) {
@@ -229,7 +217,7 @@ func TestValidateAcceptsApplicationPrefixType(t *testing.T) {
229217

230218
// RFC 9068 / RFC 7515 permit the media type with an "application/" prefix.
231219
token := mintToken(t, priv, testKID, "application/at+jwt", validClaims(time.Now()), RequiredScope)
232-
if err := validateErr(v, chatRequest(token)); err != nil {
220+
if err := v.Validate(chatRequest(token)); err != nil {
233221
t.Fatalf("expected application/at+jwt to be accepted, got %v", err)
234222
}
235223
}
@@ -243,7 +231,7 @@ func TestValidateAcceptsNonChatPath(t *testing.T) {
243231
// The inference:api scope authorizes every inference endpoint, not just
244232
// chat completions, so a non-chat path must validate.
245233
token := mintToken(t, priv, testKID, "at+jwt", validClaims(time.Now()), RequiredScope)
246-
if err := validateErr(v, key.Request{APIKey: token, Path: "/v1/embeddings"}); err != nil {
234+
if err := v.Validate(key.Request{APIKey: token, Path: "/v1/embeddings"}); err != nil {
247235
t.Fatalf("expected non-chat path to be accepted, got %v", err)
248236
}
249237
}
@@ -322,7 +310,7 @@ func TestValidateRefreshesUnknownKidAfterRecentSuccess(t *testing.T) {
322310
useSecond.Store(true)
323311

324312
token := mintToken(t, secondPriv, "test-key-2", "at+jwt", validClaims(time.Now()), RequiredScope)
325-
if err := validateErr(v, chatRequest(token)); err != nil {
313+
if err := v.Validate(chatRequest(token)); err != nil {
326314
t.Fatalf("expected unknown kid to refresh immediately, got %v", err)
327315
}
328316
}
@@ -356,7 +344,7 @@ func TestNewValidatorRecoversWhenJWKSStartsUnavailable(t *testing.T) {
356344
v := newTestValidator(t, srv.URL)
357345

358346
token := mintToken(t, priv, testKID, "at+jwt", validClaims(time.Now()), RequiredScope)
359-
if err := validateErr(v, chatRequest(token)); err == nil {
347+
if err := v.Validate(chatRequest(token)); err == nil {
360348
t.Fatal("expected rejection while no signing keys are cached")
361349
}
362350

@@ -367,7 +355,7 @@ func TestNewValidatorRecoversWhenJWKSStartsUnavailable(t *testing.T) {
367355
v.keys.lastAttempt = time.Now().Add(-2 * minRefreshInterval)
368356
v.keys.mu.Unlock()
369357

370-
if err := validateErr(v, chatRequest(token)); err != nil {
358+
if err := v.Validate(chatRequest(token)); err != nil {
371359
t.Fatalf("expected token to validate after JWKS became available, got %v", err)
372360
}
373361
}

tinfoil/internal/key/key.go

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,8 @@ type Request struct {
1414
Path string `json:"path,omitempty"`
1515
}
1616

17-
// Result is the non-secret outcome of a successful validation. It lets the shim
18-
// forward the authenticated principal to the upstream workload so the workload
19-
// can attribute usage and rate limits to a stable identity instead of the
20-
// (rotating) bearer credential.
21-
type Result struct {
22-
// Subject is the authenticated principal — the JWT `sub` — when the
23-
// credential is a locally-verified JWT access token. It is empty for opaque
24-
// API keys, whose subject is known only to the control plane.
25-
Subject string
26-
}
27-
2817
type Validator interface {
29-
Validate(req Request) (Result, error)
18+
Validate(req Request) error
3019
}
3120

3221
// ErrUnsupportedToken signals that a Validator cannot handle the presented
@@ -55,14 +44,13 @@ func NewChain(validators ...Validator) *Chain {
5544
return &Chain{validators: validators}
5645
}
5746

58-
func (c *Chain) Validate(req Request) (Result, error) {
59-
var res Result
47+
func (c *Chain) Validate(req Request) error {
6048
var err error
6149
for _, v := range c.validators {
62-
res, err = v.Validate(req)
50+
err = v.Validate(req)
6351
if !errors.Is(err, ErrUnsupportedToken) {
64-
return res, err
52+
return err
6553
}
6654
}
67-
return res, err
55+
return err
6856
}

tinfoil/internal/key/key_test.go

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@ import (
77
)
88

99
type stubValidator struct {
10-
result Result
1110
err error
1211
called *int
1312
}
1413

15-
func (s stubValidator) Validate(Request) (Result, error) {
14+
func (s stubValidator) Validate(Request) error {
1615
if s.called != nil {
1716
*s.called++
1817
}
19-
return s.result, s.err
18+
return s.err
2019
}
2120

2221
func TestChainFallsThroughOnUnsupported(t *testing.T) {
@@ -25,7 +24,7 @@ func TestChainFallsThroughOnUnsupported(t *testing.T) {
2524
stubValidator{err: ErrUnsupportedToken, called: &firstCalls},
2625
stubValidator{err: nil, called: &secondCalls},
2726
)
28-
if _, err := chain.Validate(Request{APIKey: "x"}); err != nil {
27+
if err := chain.Validate(Request{APIKey: "x"}); err != nil {
2928
t.Fatalf("expected success, got %v", err)
3029
}
3130
if firstCalls != 1 || secondCalls != 1 {
@@ -39,7 +38,7 @@ func TestChainReturnsValidationErrorImmediately(t *testing.T) {
3938
stubValidator{err: &ValidationError{StatusCode: http.StatusUnauthorized}},
4039
stubValidator{err: nil, called: &secondCalls},
4140
)
42-
_, err := chain.Validate(Request{APIKey: "x"})
41+
err := chain.Validate(Request{APIKey: "x"})
4342
var ve *ValidationError
4443
if !errors.As(err, &ve) || ve.StatusCode != http.StatusUnauthorized {
4544
t.Fatalf("expected 401 ValidationError, got %v", err)
@@ -55,27 +54,10 @@ func TestChainShortCircuitsOnSuccess(t *testing.T) {
5554
stubValidator{err: nil},
5655
stubValidator{err: ErrUnsupportedToken, called: &secondCalls},
5756
)
58-
if _, err := chain.Validate(Request{APIKey: "x"}); err != nil {
57+
if err := chain.Validate(Request{APIKey: "x"}); err != nil {
5958
t.Fatalf("expected success, got %v", err)
6059
}
6160
if secondCalls != 0 {
6261
t.Fatalf("second validator should not be consulted, calls=%d", secondCalls)
6362
}
6463
}
65-
66-
// TestChainPropagatesSubject verifies the Result from the validator that
67-
// handled the request (here the second, after the first falls through) is
68-
// returned to the caller, so the shim can forward the verified subject.
69-
func TestChainPropagatesSubject(t *testing.T) {
70-
chain := NewChain(
71-
stubValidator{err: ErrUnsupportedToken},
72-
stubValidator{result: Result{Subject: "user_42"}, err: nil},
73-
)
74-
res, err := chain.Validate(Request{APIKey: "x"})
75-
if err != nil {
76-
t.Fatalf("expected success, got %v", err)
77-
}
78-
if res.Subject != "user_42" {
79-
t.Fatalf("Subject = %q, want user_42", res.Subject)
80-
}
81-
}

0 commit comments

Comments
 (0)