diff --git a/service/kas/access/accessPdp.go b/service/kas/access/accessPdp.go index 0c9359bc5..1253bfe76 100644 --- a/service/kas/access/accessPdp.go +++ b/service/kas/access/accessPdp.go @@ -23,12 +23,13 @@ var decryptAction = &policy.Action{ } type PDPAccessResult struct { - Access bool - Error error - Policy *Policy + Access bool + Error error + Policy *Policy + RequiredObligations []string } -func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies []*Policy) ([]PDPAccessResult, error) { +func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies []*Policy, fulfillableObligationFQNs []string) ([]PDPAccessResult, error) { var res []PDPAccessResult var resources []*authzV2.Resource idPolicyMap := make(map[string]*Policy) @@ -67,7 +68,7 @@ func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies ctx, span := p.Start(ctx, "checkAttributes") defer span.End() - resourceDecisions, err := p.checkAttributes(ctx, resources, token) + resourceDecisions, err := p.checkAttributes(ctx, resources, token, fulfillableObligationFQNs) if err != nil { return nil, err } @@ -78,14 +79,14 @@ func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies p.Logger.WarnContext(ctx, "unexpected ephemeral resource id not mapped to a policy") continue } - res = append(res, PDPAccessResult{Policy: policy, Access: decision.GetDecision() == authzV2.Decision_DECISION_PERMIT}) + res = append(res, PDPAccessResult{Policy: policy, Access: decision.GetDecision() == authzV2.Decision_DECISION_PERMIT, RequiredObligations: decision.GetRequiredObligations()}) } return res, nil } // checkAttributes makes authorization service GetDecision requests to check access to resources -func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Resource, ent *entity.Token) ([]*authzV2.ResourceDecision, error) { +func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Resource, ent *entity.Token, fulfillableObligationFQNs []string) ([]*authzV2.ResourceDecision, error) { ctx = tracing.InjectTraceContext(ctx) // If only one resource, prefer singular endpoint @@ -94,8 +95,9 @@ func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Res EntityIdentifier: &authzV2.EntityIdentifier{ Identifier: &authzV2.EntityIdentifier_Token{Token: ent}, }, - Action: decryptAction, - Resource: resources[0], + Action: decryptAction, + Resource: resources[0], + FulfillableObligationFqns: fulfillableObligationFQNs, } dr, err := p.SDK.AuthorizationV2.GetDecision(ctx, req) if err != nil { @@ -110,8 +112,9 @@ func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Res EntityIdentifier: &authzV2.EntityIdentifier{ Identifier: &authzV2.EntityIdentifier_Token{Token: ent}, }, - Action: decryptAction, - Resources: resources, + Action: decryptAction, + Resources: resources, + FulfillableObligationFqns: fulfillableObligationFQNs, } dr, err := p.SDK.AuthorizationV2.GetDecisionMultiResource(ctx, req) diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index 2f866c68a..2fd9bf5d6 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -16,12 +16,14 @@ import ( "fmt" "log/slog" "net/http" + "strings" "time" "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/opentdf/platform/lib/identifier" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/protocol/go/entity" kaspb "github.com/opentdf/platform/protocol/go/kas" @@ -35,13 +37,21 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/structpb" ) const ( - kTDF3Algorithm = "rsa:2048" - kNanoAlgorithm = "ec:secp256r1" - kFailedStatus = "fail" - kPermitStatus = "permit" + kTDF3Algorithm = "rsa:2048" + kNanoAlgorithm = "ec:secp256r1" + kFailedStatus = "fail" + kPermitStatus = "permit" + additionalRewrapContextHeader = "X-Rewrap-Additional-Context" + requiredObligationsHeader = "X-Required-Obligations" +) + +var ( + ErrDecodingRewrapContext = errors.New("failed to decode additional rewrap context") + ErrUnmarshalingRewrapContext = errors.New("failed to unmarshal additional rewrap context") ) type SignedRequestBody struct { @@ -74,8 +84,21 @@ type kaoResult struct { EphemeralPublicKey []byte } +type policyResult struct { + kaoResults map[string]kaoResult + requiredObligations []string +} + // From policy ID to KAO ID to result -type policyKAOResults map[string]map[string]kaoResult +type policyKAOResults map[string]policyResult + +type ObligationCtx struct { + FulfillableFQNs []string `json:"fulfillableFQNs,omitempty"` +} + +type AdditionalRewrapContext struct { + Obligations ObligationCtx `json:"obligations"` +} const ( kNanoTDFGMACLength = 8 @@ -367,7 +390,7 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult policyResults := &kaspb.PolicyRewrapResult{ PolicyId: policyID, } - for kaoID, kaoRes := range policyMap { + for kaoID, kaoRes := range policyMap.kaoResults { kaoResult := &kaspb.KeyAccessRewrapResult{ KeyAccessObjectId: kaoID, } @@ -385,6 +408,7 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult policyResults.Results = append(policyResults.Results, kaoResult) } response.Responses = append(response.Responses, policyResults) + populateRequiredObligationsOnResponse(response, policyMap.requiredObligations, policyID) } } @@ -428,11 +452,16 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap } } var results policyKAOResults + additionalRewrapContext, err := getAdditionalRewrapContext(req.Header()) + if err != nil { + p.Logger.WarnContext(ctx, "failed to get additional rewrap context", slog.Any("error", err)) + return nil, err400(err.Error()) + } if len(tdf3Reqs) > 0 { - resp.SessionPublicKey, results = p.tdf3Rewrap(ctx, tdf3Reqs, body.GetClientPublicKey(), entityInfo) + resp.SessionPublicKey, results = p.tdf3Rewrap(ctx, tdf3Reqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext) addResultsToResponse(resp, results) } else { - resp.SessionPublicKey, results = p.nanoTDFRewrap(ctx, nanoReqs, body.GetClientPublicKey(), entityInfo) + resp.SessionPublicKey, results = p.nanoTDFRewrap(ctx, nanoReqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext) addResultsToResponse(resp, results) } @@ -441,16 +470,16 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap p.Logger.WarnContext(ctx, "status 400 due to wrong result set size", slog.Any("results", results)) return nil, err400("invalid request") } - kaoResults := *getMapValue(results) - if len(kaoResults) != 1 { + policyResults := *getMapValue(results) + if len(policyResults.kaoResults) != 1 { p.Logger.WarnContext(ctx, "status 400 due to wrong result set size", - slog.Any("kao_results", kaoResults), + slog.Any("kao_results", policyResults.kaoResults), slog.Any("results", results), ) return nil, err400("invalid request") } - kao := *getMapValue(kaoResults) + kao := *getMapValue(policyResults.kaoResults) if kao.Error != nil { p.Logger.DebugContext(ctx, "forwarding legacy err", slog.Any("error", kao.Error)) @@ -665,7 +694,7 @@ func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier { return kidsToCheck } -func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo) (string, policyKAOResults) { +func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults) { if p.Tracer != nil { var span trace.Span ctx, span = p.Start(ctx, "rewrap-tdf3") @@ -678,7 +707,10 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew for _, req := range requests { policy, kaoResults, err := p.verifyRewrapRequests(ctx, req) policyID := req.GetPolicy().GetId() - results[policyID] = kaoResults + results[policyID] = policyResult{ + kaoResults: kaoResults, + requiredObligations: []string{}, + } if err != nil { p.Logger.WarnContext(ctx, "rewrap: verifyRewrapRequests failed", @@ -695,7 +727,8 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew EphemeralId: "rewrap-token", Jwt: entityInfo.Token, } - pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies) + + pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies, additionalRewrapContext.Obligations.FulfillableFQNs) if accessErr != nil { p.Logger.DebugContext(ctx, "tdf3rewrap: cannot access policy", @@ -738,12 +771,14 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew p.Logger.WarnContext(ctx, "policy not found in policyReqs", "policy.uuid", policy.UUID) continue } - kaoResults, ok := results[req.GetPolicy().GetId()] + + policyRes, ok := results[req.GetPolicy().GetId()] if !ok { // this should not happen //nolint:sloglint // reference to key is intentional p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID) continue } + kaoResults := policyRes.kaoResults access := pdpAccess.Access // Audit the TDF3 Rewrap @@ -788,11 +823,15 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew p.Logger.Audit.RewrapSuccess(ctx, auditEventParams) } + results[req.GetPolicy().GetId()] = policyResult{ + kaoResults: kaoResults, + requiredObligations: pdpAccess.RequiredObligations, + } } return sessionKey, results } -func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo) (string, policyKAOResults) { +func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults) { ctx, span := p.Start(ctx, "nanoTDFRewrap") defer span.End() @@ -803,7 +842,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned for _, req := range requests { policy, kaoResults := p.verifyNanoRewrapRequests(ctx, req) - results[req.GetPolicy().GetId()] = kaoResults + results[req.GetPolicy().GetId()] = policyResult{kaoResults: kaoResults, requiredObligations: []string{}} if policy != nil { policies = append(policies, policy) policyReqs[policy] = req @@ -815,7 +854,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned Jwt: entityInfo.Token, } - pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies) + pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies, additionalRewrapContext.Obligations.FulfillableFQNs) if accessErr != nil { failAllKaos(requests, results, err500("could not perform access")) return "", results @@ -840,7 +879,13 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned if !ok { // this should not happen continue } - kaoResults := results[req.GetPolicy().GetId()] + policyRes, ok := results[req.GetPolicy().GetId()] + if !ok { // this should not happen + //nolint:sloglint // reference to key is intentional + p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID) + continue + } + kaoResults := policyRes.kaoResults access := pdpAccess.Access // Audit the Nano Rewrap @@ -878,6 +923,10 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned p.Logger.Audit.RewrapSuccess(ctx, auditEventParams) } + results[req.GetPolicy().GetId()] = policyResult{ + kaoResults: kaoResults, + requiredObligations: pdpAccess.RequiredObligations, + } } return sessionKeyPEM, results } @@ -996,7 +1045,82 @@ func extractNanoPolicy(symmetricKey ocrypto.ProtectedKey, header sdk.NanoTDFHead func failAllKaos(reqs []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, results policyKAOResults, err error) { for _, req := range reqs { for _, kao := range req.GetKeyAccessObjects() { - failedKAORewrap(results[req.GetPolicy().GetId()], kao, err) + failedKAORewrap(results[req.GetPolicy().GetId()].kaoResults, kao, err) + } + } +} + +// Populate response metadata with required obligations +func populateRequiredObligationsOnResponse(response *kaspb.RewrapResponse, obligations []string, policyID string) { + metadata := response.GetMetadata() + if metadata == nil { + metadata = make(map[string]*structpb.Value) + } + + var fields map[string]*structpb.Value + obligationValue, ok := metadata[requiredObligationsHeader] + if !ok || obligationValue.GetStructValue() == nil { + fields = make(map[string]*structpb.Value) + metadata[requiredObligationsHeader] = structpb.NewStructValue(&structpb.Struct{ + Fields: fields, + }) + } else { + fields = obligationValue.GetStructValue().GetFields() + } + + values := make([]*structpb.Value, len(obligations)) + for i, obligation := range obligations { + values[i] = structpb.NewStringValue(obligation) + } + fields[policyID] = structpb.NewListValue(&structpb.ListValue{ + Values: values, + }) + response.Metadata = metadata +} + +// Retrieve additional request context needed for rewrap processing +// Header is json encoded AdditionalRewrapContext struct +/* +Example: + +{ + "obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","https://demo.com/obl/test/value/geofence"]} +} + +*/ +func getAdditionalRewrapContext(header http.Header) (*AdditionalRewrapContext, error) { + rewrapContext := &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + } + if header == nil { + return rewrapContext, nil + } + if val := header.Get(additionalRewrapContextHeader); val != "" { + decoded, err := base64.StdEncoding.DecodeString(val) + if err != nil { + return nil, errors.Join(ErrDecodingRewrapContext, err) + } + + err = json.Unmarshal(decoded, rewrapContext) + if err != nil { + return nil, errors.Join(ErrUnmarshalingRewrapContext, err) + } + + validObligations := make([]string, 0) + for _, r := range rewrapContext.Obligations.FulfillableFQNs { + normalizedObligation := strings.TrimSpace(r) + if len(normalizedObligation) == 0 { + continue + } + _, err = identifier.Parse[*identifier.FullyQualifiedObligation](normalizedObligation) + if err != nil { + return nil, fmt.Errorf("%w, for obligation %s", err, normalizedObligation) + } + validObligations = append(validObligations, normalizedObligation) } + rewrapContext.Obligations.FulfillableFQNs = validObligations } + return rewrapContext, nil } diff --git a/service/kas/access/rewrap_test.go b/service/kas/access/rewrap_test.go index 15f141ec1..d3873a067 100644 --- a/service/kas/access/rewrap_test.go +++ b/service/kas/access/rewrap_test.go @@ -14,11 +14,13 @@ import ( "testing" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/structpb" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/opentdf/platform/lib/identifier" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/service/logger" ctxAuth "github.com/opentdf/platform/service/pkg/auth" @@ -551,3 +553,436 @@ func Test_GetEntityInfo_When_Authorization_MD_Invalid_Expect_Error(t *testing.T) require.Error(t, err) require.Contains(t, err.Error(), "missing") } + +func TestGetAdditionalRewrapContext(t *testing.T) { + tests := []struct { + name string + header http.Header + expectedResult *AdditionalRewrapContext + expectedError error + errorContains string + }{ + { + name: "nil header", + header: nil, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + { + name: "empty header", + header: make(http.Header), + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + { + name: "header without obligations", + header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + { + name: "valid single watermark obligation", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark"]}}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedError: nil, + }, + { + name: "valid multiple obligations", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","https://demo.com/obl/test/value/geofence"]}}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{"https://demo.com/obl/test/value/watermark", "https://demo.com/obl/test/value/geofence"}, + }, + }, + expectedError: nil, + }, + { + name: "mixed valid and invalid fqns", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","https://example.com/attr/Classification/value/restricted","https://virtru.com/obl/test/value/audit"]}}`))}, + }, + expectedResult: nil, + expectedError: identifier.ErrInvalidFQNFormat, + errorContains: "https://example.com/attr/Classification/value/restricted", + }, + { + name: "empty obligations array", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": []}}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + { + name: "no fulfillableFQNs array", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {}}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + { + name: "no obligations array", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + { + name: "obligations with empty values filtered out", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","","https://demo.com/obl/test/value/geofence"]}}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{"https://demo.com/obl/test/value/watermark", "https://demo.com/obl/test/value/geofence"}, + }, + }, + expectedError: nil, + }, + { + name: "obligations with whitespace trimmed", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": [" https://demo.com/obl/test/value/watermark "," https://demo.com/obl/test/value/geofence "]}}`))}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{"https://demo.com/obl/test/value/watermark", "https://demo.com/obl/test/value/geofence"}, + }, + }, + expectedError: nil, + }, + { + name: "invalid FQN format obligation", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": ["invalid-obligation-format"]}}`))}, + }, + expectedResult: nil, + expectedError: identifier.ErrInvalidFQNFormat, + errorContains: "invalid-obligation-format", + }, + { + name: "mixed invalid FQN format obligation", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","invalid-obligation-format"]}}`))}, + }, + expectedResult: nil, + expectedError: identifier.ErrInvalidFQNFormat, + errorContains: "invalid-obligation-format", + }, + { + name: "invalid base64 encoding", + header: http.Header{ + additionalRewrapContextHeader: []string{`{"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","invalid-obligation-format"]}}`}, + }, + expectedResult: nil, + expectedError: ErrDecodingRewrapContext, + }, + { + name: "invalid JSON format", + header: http.Header{ + additionalRewrapContextHeader: []string{base64.StdEncoding.EncodeToString([]byte(`{invalid json}`))}, + }, + expectedResult: nil, + expectedError: ErrUnmarshalingRewrapContext, + }, + { + name: "empty base64 string", + header: http.Header{ + additionalRewrapContextHeader: []string{""}, + }, + expectedResult: &AdditionalRewrapContext{ + Obligations: ObligationCtx{ + FulfillableFQNs: []string{}, + }, + }, + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := getAdditionalRewrapContext(tt.header) + + if tt.expectedError != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectedError) + if tt.errorContains != "" { + require.ErrorContains(t, err, tt.errorContains) + } + require.Nil(t, result) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedResult, result) + } + }) + } +} + +func TestPopulateRequiredObligationsOnResponse(t *testing.T) { + type policyObligation struct { + obligations []string + policyID string + } + + tests := []struct { + name string + response *kaspb.RewrapResponse + policies []policyObligation + validate func(t *testing.T, response *kaspb.RewrapResponse) + }{ + { + name: "single policy with single obligation", + response: &kaspb.RewrapResponse{ + Metadata: make(map[string]*structpb.Value), + }, + policies: []policyObligation{ + { + obligations: []string{"https://demo.com/obl/test/value/watermark"}, + policyID: "policy1", + }, + }, + validate: func(t *testing.T, response *kaspb.RewrapResponse) { + metadata := response.GetMetadata() + require.Contains(t, metadata, requiredObligationsHeader) //nolint:staticcheck // testing deprecated field + + structValue := metadata[requiredObligationsHeader].GetStructValue() //nolint:staticcheck // testing deprecated field + require.NotNil(t, structValue) + require.Contains(t, structValue.GetFields(), "policy1") + + fields := structValue.GetFields() + listValue := fields["policy1"].GetListValue() + require.NotNil(t, listValue) + require.Len(t, listValue.GetValues(), 1) + assert.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) + }, + }, + { + name: "single policy with multiple obligations", + response: &kaspb.RewrapResponse{ + Metadata: make(map[string]*structpb.Value), + }, + policies: []policyObligation{ + { + obligations: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + policyID: "policy1", + }, + }, + validate: func(t *testing.T, response *kaspb.RewrapResponse) { + metadata := response.GetMetadata() + require.Contains(t, metadata, requiredObligationsHeader) + + structValue := metadata[requiredObligationsHeader].GetStructValue() + require.NotNil(t, structValue) + require.Contains(t, structValue.GetFields(), "policy1") + + fields := structValue.GetFields() + listValue := fields["policy1"].GetListValue() + require.NotNil(t, listValue) + require.Len(t, listValue.GetValues(), 2) + assert.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) + assert.Equal(t, "https://demo.com/obl/test/value/geofence", listValue.GetValues()[1].GetStringValue()) + }, + }, + { + name: "multiple policies with different obligations", + response: &kaspb.RewrapResponse{ + Metadata: make(map[string]*structpb.Value), + }, + policies: []policyObligation{ + { + obligations: []string{"https://demo.com/obl/test/value/watermark"}, + policyID: "policy1", + }, + { + obligations: []string{"https://demo.com/obl/test/value/geofence"}, + policyID: "policy2", + }, + { + obligations: []string{"https://example.com/obl/test/value/mfa"}, + policyID: "policy3", + }, + }, + validate: func(t *testing.T, response *kaspb.RewrapResponse) { + metadata := response.GetMetadata() + require.Contains(t, metadata, requiredObligationsHeader) + + structValue := metadata[requiredObligationsHeader].GetStructValue() + require.NotNil(t, structValue) + fields := structValue.GetFields() + + // Verify policy1 + require.Contains(t, fields, "policy1") + listValue1 := fields["policy1"].GetListValue() + require.NotNil(t, listValue1) + require.Len(t, listValue1.GetValues(), 1) + require.Equal(t, "https://demo.com/obl/test/value/watermark", listValue1.GetValues()[0].GetStringValue()) + + // Verify policy2 + require.Contains(t, fields, "policy2") + listValue2 := fields["policy2"].GetListValue() + require.NotNil(t, listValue2) + require.Len(t, listValue2.GetValues(), 1) + require.Equal(t, "https://demo.com/obl/test/value/geofence", listValue2.GetValues()[0].GetStringValue()) + + // Verify policy3 + require.Contains(t, fields, "policy3") + listValue3 := fields["policy3"].GetListValue() + require.NotNil(t, listValue3) + require.Len(t, listValue3.GetValues(), 1) + require.Equal(t, "https://example.com/obl/test/value/mfa", listValue3.GetValues()[0].GetStringValue()) + }, + }, + { + name: "empty obligations list", + response: &kaspb.RewrapResponse{ + Metadata: make(map[string]*structpb.Value), + }, + policies: []policyObligation{ + { + obligations: []string{}, + policyID: "policy1", + }, + }, + validate: func(t *testing.T, response *kaspb.RewrapResponse) { + metadata := response.GetMetadata() + require.Contains(t, metadata, requiredObligationsHeader) //nolint:staticcheck // testing deprecated field + + structValue := metadata[requiredObligationsHeader].GetStructValue() //nolint:staticcheck // testing deprecated field + require.NotNil(t, structValue) + require.Contains(t, structValue.GetFields(), "policy1") + + fields := structValue.GetFields() + listValue := fields["policy1"].GetListValue() + require.NotNil(t, listValue) + require.Empty(t, listValue.GetValues()) + }, + }, + { + name: "nil response metadata", + response: &kaspb.RewrapResponse{ + Metadata: nil, + }, + policies: []policyObligation{ + { + obligations: []string{"https://demo.com/obl/test/value/watermark"}, + policyID: "policy1", + }, + }, + validate: func(t *testing.T, response *kaspb.RewrapResponse) { + metadata := response.GetMetadata() + require.NotNil(t, metadata) //nolint:staticcheck // testing deprecated field + require.Contains(t, metadata, requiredObligationsHeader) //nolint:staticcheck // testing deprecated field + + structValue := metadata[requiredObligationsHeader].GetStructValue() //nolint:staticcheck // testing deprecated field + require.NotNil(t, structValue) + require.Contains(t, structValue.GetFields(), "policy1") + + fields := structValue.GetFields() + listValue := fields["policy1"].GetListValue() + require.NotNil(t, listValue) + require.Len(t, listValue.GetValues(), 1) + require.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) + }, + }, + { + name: "preserve existing metadata when adding obligations", + response: &kaspb.RewrapResponse{ + Metadata: map[string]*structpb.Value{ + "existing-header": structpb.NewStringValue("existing-value"), + "session-info": structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "sessionId": structpb.NewStringValue("session-123"), + "timestamp": structpb.NewNumberValue(1672531200), + }, + }), + }, + }, + policies: []policyObligation{ + { + obligations: []string{ + "https://demo.com/obl/test/value/watermark", + }, + policyID: "policy1", + }, + }, + validate: func(t *testing.T, response *kaspb.RewrapResponse) { + metadata := response.GetMetadata() + require.NotNil(t, metadata) //nolint:staticcheck // testing deprecated field + + // Verify existing metadata is preserved + require.Contains(t, metadata, "existing-header") + require.Equal(t, "existing-value", metadata["existing-header"].GetStringValue()) + + require.Contains(t, metadata, "session-info") + sessionInfo := metadata["session-info"].GetStructValue() + require.NotNil(t, sessionInfo) + sessionFields := sessionInfo.GetFields() + require.Contains(t, sessionFields, "sessionId") + require.Equal(t, "session-123", sessionFields["sessionId"].GetStringValue()) + require.Contains(t, sessionFields, "timestamp") + require.InDelta(t, float64(1672531200), sessionFields["timestamp"].GetNumberValue(), 0.001) + + // Verify new obligations are added + require.Contains(t, metadata, requiredObligationsHeader) + structValue := metadata[requiredObligationsHeader].GetStructValue() + require.NotNil(t, structValue) + require.Contains(t, structValue.GetFields(), "policy1") + + obligationFields := structValue.GetFields() + listValue := obligationFields["policy1"].GetListValue() + require.NotNil(t, listValue) + require.Len(t, listValue.GetValues(), 1) + require.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call populateRequiredObligationsOnResponse for each policy + for _, policy := range tt.policies { + populateRequiredObligationsOnResponse(tt.response, policy.obligations, policy.policyID) + } + tt.validate(t, tt.response) + }) + } +}