Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions service/kas/access/accessPdp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ var decryptAction = &policy.Action{
}

type PDPAccessResult struct {
Access bool
Error error
Policy *Policy
Access bool
Error error
Policy *Policy
RequiredObligations []string
}

func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies []*Policy) ([]PDPAccessResult, error) {
func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies []*Policy, fulfillableObligationFQNs []string) ([]PDPAccessResult, error) {
var res []PDPAccessResult
var resources []*authzV2.Resource
idPolicyMap := make(map[string]*Policy)
Expand Down Expand Up @@ -67,7 +68,7 @@ func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies
ctx, span := p.Start(ctx, "checkAttributes")
defer span.End()

resourceDecisions, err := p.checkAttributes(ctx, resources, token)
resourceDecisions, err := p.checkAttributes(ctx, resources, token, fulfillableObligationFQNs)
if err != nil {
return nil, err
}
Expand All @@ -78,14 +79,14 @@ func (p *Provider) canAccess(ctx context.Context, token *entity.Token, policies
p.Logger.WarnContext(ctx, "unexpected ephemeral resource id not mapped to a policy")
continue
}
res = append(res, PDPAccessResult{Policy: policy, Access: decision.GetDecision() == authzV2.Decision_DECISION_PERMIT})
res = append(res, PDPAccessResult{Policy: policy, Access: decision.GetDecision() == authzV2.Decision_DECISION_PERMIT, RequiredObligations: decision.GetRequiredObligations()})
}

return res, nil
}

// checkAttributes makes authorization service GetDecision requests to check access to resources
func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Resource, ent *entity.Token) ([]*authzV2.ResourceDecision, error) {
func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Resource, ent *entity.Token, fulfillableObligationFQNs []string) ([]*authzV2.ResourceDecision, error) {
ctx = tracing.InjectTraceContext(ctx)

// If only one resource, prefer singular endpoint
Expand All @@ -94,8 +95,9 @@ func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Res
EntityIdentifier: &authzV2.EntityIdentifier{
Identifier: &authzV2.EntityIdentifier_Token{Token: ent},
},
Action: decryptAction,
Resource: resources[0],
Action: decryptAction,
Resource: resources[0],
FulfillableObligationFqns: fulfillableObligationFQNs,
}
dr, err := p.SDK.AuthorizationV2.GetDecision(ctx, req)
if err != nil {
Expand All @@ -110,8 +112,9 @@ func (p *Provider) checkAttributes(ctx context.Context, resources []*authzV2.Res
EntityIdentifier: &authzV2.EntityIdentifier{
Identifier: &authzV2.EntityIdentifier_Token{Token: ent},
},
Action: decryptAction,
Resources: resources,
Action: decryptAction,
Resources: resources,
FulfillableObligationFqns: fulfillableObligationFQNs,
}

dr, err := p.SDK.AuthorizationV2.GetDecisionMultiResource(ctx, req)
Expand Down
166 changes: 145 additions & 21 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"

"connectrpc.com/connect"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/lib/identifier"
"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/protocol/go/entity"
kaspb "github.com/opentdf/platform/protocol/go/kas"
Expand All @@ -35,13 +37,21 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"
)

const (
kTDF3Algorithm = "rsa:2048"
kNanoAlgorithm = "ec:secp256r1"
kFailedStatus = "fail"
kPermitStatus = "permit"
kTDF3Algorithm = "rsa:2048"
kNanoAlgorithm = "ec:secp256r1"
kFailedStatus = "fail"
kPermitStatus = "permit"
additionalRewrapContextHeader = "X-Rewrap-Additional-Context"
requiredObligationsHeader = "X-Required-Obligations"
)

var (
ErrDecodingRewrapContext = errors.New("failed to decode additional rewrap context")
ErrUnmarshalingRewrapContext = errors.New("failed to unmarshal additional rewrap context")
)

type SignedRequestBody struct {
Expand Down Expand Up @@ -74,8 +84,21 @@ type kaoResult struct {
EphemeralPublicKey []byte
}

type policyResult struct {
kaoResults map[string]kaoResult
requiredObligations []string
}

// From policy ID to KAO ID to result
type policyKAOResults map[string]map[string]kaoResult
type policyKAOResults map[string]policyResult

type ObligationCtx struct {
FulfillableFQNs []string `json:"fulfillableFQNs,omitempty"`
}

type AdditionalRewrapContext struct {
Obligations ObligationCtx `json:"obligations"`
}

const (
kNanoTDFGMACLength = 8
Expand Down Expand Up @@ -367,7 +390,7 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult
policyResults := &kaspb.PolicyRewrapResult{
PolicyId: policyID,
}
for kaoID, kaoRes := range policyMap {
for kaoID, kaoRes := range policyMap.kaoResults {
kaoResult := &kaspb.KeyAccessRewrapResult{
KeyAccessObjectId: kaoID,
}
Expand All @@ -385,6 +408,7 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult
policyResults.Results = append(policyResults.Results, kaoResult)
}
response.Responses = append(response.Responses, policyResults)
populateRequiredObligationsOnResponse(response, policyMap.requiredObligations, policyID)
}
}

Expand Down Expand Up @@ -428,11 +452,16 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
}
}
var results policyKAOResults
additionalRewrapContext, err := getAdditionalRewrapContext(req.Header())
if err != nil {
p.Logger.WarnContext(ctx, "failed to get additional rewrap context", slog.Any("error", err))
return nil, err400(err.Error())
}
if len(tdf3Reqs) > 0 {
resp.SessionPublicKey, results = p.tdf3Rewrap(ctx, tdf3Reqs, body.GetClientPublicKey(), entityInfo)
resp.SessionPublicKey, results = p.tdf3Rewrap(ctx, tdf3Reqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext)
addResultsToResponse(resp, results)
} else {
resp.SessionPublicKey, results = p.nanoTDFRewrap(ctx, nanoReqs, body.GetClientPublicKey(), entityInfo)
resp.SessionPublicKey, results = p.nanoTDFRewrap(ctx, nanoReqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext)
addResultsToResponse(resp, results)
}

Expand All @@ -441,16 +470,16 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
p.Logger.WarnContext(ctx, "status 400 due to wrong result set size", slog.Any("results", results))
return nil, err400("invalid request")
}
kaoResults := *getMapValue(results)
if len(kaoResults) != 1 {
policyResults := *getMapValue(results)
if len(policyResults.kaoResults) != 1 {
p.Logger.WarnContext(ctx,
"status 400 due to wrong result set size",
slog.Any("kao_results", kaoResults),
slog.Any("kao_results", policyResults.kaoResults),
slog.Any("results", results),
)
return nil, err400("invalid request")
}
kao := *getMapValue(kaoResults)
kao := *getMapValue(policyResults.kaoResults)

if kao.Error != nil {
p.Logger.DebugContext(ctx, "forwarding legacy err", slog.Any("error", kao.Error))
Expand Down Expand Up @@ -665,7 +694,7 @@ func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier {
return kidsToCheck
}

func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo) (string, policyKAOResults) {
func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults) {
if p.Tracer != nil {
var span trace.Span
ctx, span = p.Start(ctx, "rewrap-tdf3")
Expand All @@ -678,7 +707,10 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
for _, req := range requests {
policy, kaoResults, err := p.verifyRewrapRequests(ctx, req)
policyID := req.GetPolicy().GetId()
results[policyID] = kaoResults
results[policyID] = policyResult{
kaoResults: kaoResults,
requiredObligations: []string{},
}
if err != nil {
p.Logger.WarnContext(ctx,
"rewrap: verifyRewrapRequests failed",
Expand All @@ -695,7 +727,8 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
EphemeralId: "rewrap-token",
Jwt: entityInfo.Token,
}
pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies)

pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies, additionalRewrapContext.Obligations.FulfillableFQNs)
if accessErr != nil {
p.Logger.DebugContext(ctx,
"tdf3rewrap: cannot access policy",
Expand Down Expand Up @@ -738,12 +771,14 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
p.Logger.WarnContext(ctx, "policy not found in policyReqs", "policy.uuid", policy.UUID)
continue
}
kaoResults, ok := results[req.GetPolicy().GetId()]

policyRes, ok := results[req.GetPolicy().GetId()]
if !ok { // this should not happen
//nolint:sloglint // reference to key is intentional
p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID)
continue
}
kaoResults := policyRes.kaoResults
access := pdpAccess.Access

// Audit the TDF3 Rewrap
Expand Down Expand Up @@ -788,11 +823,15 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew

p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
}
results[req.GetPolicy().GetId()] = policyResult{
kaoResults: kaoResults,
requiredObligations: pdpAccess.RequiredObligations,
}
}
return sessionKey, results
}

func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo) (string, policyKAOResults) {
func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults) {
ctx, span := p.Start(ctx, "nanoTDFRewrap")
defer span.End()

Expand All @@ -803,7 +842,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned

for _, req := range requests {
policy, kaoResults := p.verifyNanoRewrapRequests(ctx, req)
results[req.GetPolicy().GetId()] = kaoResults
results[req.GetPolicy().GetId()] = policyResult{kaoResults: kaoResults, requiredObligations: []string{}}
if policy != nil {
policies = append(policies, policy)
policyReqs[policy] = req
Expand All @@ -815,7 +854,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
Jwt: entityInfo.Token,
}

pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies)
pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies, additionalRewrapContext.Obligations.FulfillableFQNs)
if accessErr != nil {
failAllKaos(requests, results, err500("could not perform access"))
return "", results
Expand All @@ -840,7 +879,13 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
if !ok { // this should not happen
continue
}
kaoResults := results[req.GetPolicy().GetId()]
policyRes, ok := results[req.GetPolicy().GetId()]
if !ok { // this should not happen
//nolint:sloglint // reference to key is intentional
p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID)
continue
}
kaoResults := policyRes.kaoResults
access := pdpAccess.Access

// Audit the Nano Rewrap
Expand Down Expand Up @@ -878,6 +923,10 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned

p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
}
results[req.GetPolicy().GetId()] = policyResult{
kaoResults: kaoResults,
requiredObligations: pdpAccess.RequiredObligations,
}
}
return sessionKeyPEM, results
}
Expand Down Expand Up @@ -996,7 +1045,82 @@ func extractNanoPolicy(symmetricKey ocrypto.ProtectedKey, header sdk.NanoTDFHead
func failAllKaos(reqs []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, results policyKAOResults, err error) {
for _, req := range reqs {
for _, kao := range req.GetKeyAccessObjects() {
failedKAORewrap(results[req.GetPolicy().GetId()], kao, err)
failedKAORewrap(results[req.GetPolicy().GetId()].kaoResults, kao, err)
}
}
}

// Populate response metadata with required obligations
func populateRequiredObligationsOnResponse(response *kaspb.RewrapResponse, obligations []string, policyID string) {
metadata := response.GetMetadata()
if metadata == nil {
metadata = make(map[string]*structpb.Value)
}

var fields map[string]*structpb.Value
obligationValue, ok := metadata[requiredObligationsHeader]
if !ok || obligationValue.GetStructValue() == nil {
fields = make(map[string]*structpb.Value)
metadata[requiredObligationsHeader] = structpb.NewStructValue(&structpb.Struct{
Fields: fields,
})
} else {
fields = obligationValue.GetStructValue().GetFields()
}

values := make([]*structpb.Value, len(obligations))
for i, obligation := range obligations {
values[i] = structpb.NewStringValue(obligation)
}
fields[policyID] = structpb.NewListValue(&structpb.ListValue{
Values: values,
})
response.Metadata = metadata
}

// Retrieve additional request context needed for rewrap processing
// Header is json encoded AdditionalRewrapContext struct
/*
Example:

{
"obligations": {"fulfillableFQNs": ["https://demo.com/obl/test/value/watermark","https://demo.com/obl/test/value/geofence"]}
}

*/
func getAdditionalRewrapContext(header http.Header) (*AdditionalRewrapContext, error) {
rewrapContext := &AdditionalRewrapContext{
Obligations: ObligationCtx{
FulfillableFQNs: []string{},
},
}
if header == nil {
return rewrapContext, nil
}
if val := header.Get(additionalRewrapContextHeader); val != "" {
decoded, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return nil, errors.Join(ErrDecodingRewrapContext, err)
}

err = json.Unmarshal(decoded, rewrapContext)
if err != nil {
return nil, errors.Join(ErrUnmarshalingRewrapContext, err)
}

validObligations := make([]string, 0)
for _, r := range rewrapContext.Obligations.FulfillableFQNs {
normalizedObligation := strings.TrimSpace(r)
if len(normalizedObligation) == 0 {
continue
}
_, err = identifier.Parse[*identifier.FullyQualifiedObligation](normalizedObligation)
if err != nil {
return nil, fmt.Errorf("%w, for obligation %s", err, normalizedObligation)
}
validObligations = append(validObligations, normalizedObligation)
}
rewrapContext.Obligations.FulfillableFQNs = validObligations
}
return rewrapContext, nil
}
Loading
Loading