diff --git a/.env.example b/.env.example index 40c0976..d2b6c91 100644 --- a/.env.example +++ b/.env.example @@ -3,8 +3,14 @@ PRIVATE_KEY_BASE64= PUBLIC_KEY_BASE64= # Database -MYSQL_HOST=mysql -MYSQL_PORT=3306 -MYSQL_CASBIN_DATABASE=casbin -MYSQL_USER=admin -MYSQL_PASSWORD=secret +POSTGRES_HOST=postgres +POSTGRES_DB_DEMO=iam +POSTGRES_DB_TEST=iam_test +POSTGRES_USER=admin +POSTGRES_PASSWORD=secret +POSTGRES_SSL_MODE=disable + +# ABAC +POLICY_DIR=/usr/code/examples/abac/cmd/api/policies +JWT_ISSUER=https://abac.com +JWT_AUDIENCE=https://abac.com diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9fb9d38..98a1510 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,5 +11,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Run tests + + - name: Test core library run: make ci-test + + - name: Test examples + run: make ci-test-examples diff --git a/.golangci.yml b/.golangci.yml index 270e059..ef2ca7e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,58 +1,7 @@ -linters-settings: - dupl: - threshold: 100 - funlen: - lines: -1 - statements: 50 - goconst: - min-len: 2 - min-occurrences: 3 - gocritic: - enabled-tags: - - diagnostic - - experimental - - opinionated - - performance - - style - disabled-checks: - - dupImport - - ifElseChain - - octalLiteral - - whyNoLint - gocyclo: - min-complexity: 15 - gofmt: - rewrite-rules: - - pattern: 'interface{}' - replacement: 'any' - mnd: - # don't include the "operation" and "assign" - checks: - - argument - - case - - condition - - return - ignored-numbers: - - '0' - - '1' - - '2' - - '3' - ignored-functions: - - strings.SplitN - lll: - line-length: 140 - nolintlint: - allow-unused: false - require-explanation: true - require-specific: true - revive: - rules: - - name: unexported-return - disabled: true - - name: unused-parameter +version: "2" linters: - disable-all: true + default: none enable: - bodyclose - copyloopvar @@ -66,26 +15,95 @@ linters: - goconst - gocritic - gocyclo - - gofmt - - goimports - - mnd - goprintffuncname - gosec - - gosimple - govet - - intrange - ineffassign + - intrange - lll + - mnd - nakedret - noctx - nolintlint - revive - staticcheck - - stylecheck - unconvert - unparam - unused - whitespace + settings: + dupl: + threshold: 100 + funlen: + lines: -1 + statements: 50 + goconst: + min-len: 2 + min-occurrences: 3 + gocritic: + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + disabled-checks: + - dupImport + - ifElseChain + - octalLiteral + - whyNoLint + gocyclo: + min-complexity: 15 + lll: + line-length: 140 + mnd: + checks: + - argument + - case + - condition + - return + ignored-numbers: + - "0" + - "1" + - "2" + - "3" + ignored-functions: + - strings.SplitN + nolintlint: + require-explanation: true + require-specific: true + allow-unused: false + revive: + rules: + - name: indent-error-flow + - name: unexported-return + disabled: true + - name: unused-parameter + - name: unused-receiver + + exclusions: + generated: strict + warn-unused: true + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - path: "_test.go" + text: "unlambda" + linters: + - gocritic -run: - timeout: 5m +formatters: + enable: + - gofmt + - goimports + settings: + gofmt: + rewrite-rules: + - pattern: interface{} + replacement: any + exclusions: + generated: strict + warn-unused: true diff --git a/Makefile b/Makefile index 91fc76a..1c81bad 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ DIST_DIR := _dist -GO_CODE_DIR := cmd internal +GO_CODE_DIR := abac TEST_OUTPUT_DIR := ${DIST_DIR}/tests -BUILD_DIR := ${DIST_DIR}/build -DEFAULT_VERSION := v0.0.0 +EXAMPLES_DIR ?= examples +EXAMPLE_DIRS := $(shell find $(EXAMPLES_DIR) -mindepth 1 -maxdepth 1 -type d -print) # Docker .PHONY: up @@ -37,15 +37,15 @@ ci-%: create-dev-env .PHONY: test test: lint-actions lint-go test-go -.PHONY: build -build: cleanup-build test - @echo "Running api-casbin and api-opa in parallel..." - @$(MAKE) -j 2 api-casbin api-opa +.PHONY: test-examples +test-examples: + @for d in $(EXAMPLE_DIRS); do $(MAKE) -C "$$d" test; done ## App .PHONY: lint-go lint-go: @echo "Running Go linter on code in $(GO_CODE_DIR)..." + @golangci-lint fmt $(addsuffix /..., $(GO_CODE_DIR)) -v @golangci-lint run $(addsuffix /..., $(GO_CODE_DIR)) -v .PHONY: test-go @@ -64,20 +64,6 @@ test-go: $(addprefix `pwd`/, $(addsuffix /..., $(GO_CODE_DIR))) @go tool cover -html=${TEST_OUTPUT_DIR}/cp.out -o ${TEST_OUTPUT_DIR}/cp.html -.PHONY: api-% -api-%: - @CURRENT_VERSION=$(shell git describe --tags --abbrev=0 2>/dev/null || echo $(DEFAULT_VERSION)); \ - echo "Building api (version $$CURRENT_VERSION) using $*..."; \ - go build -o ${BUILD_DIR}/api-$* \ - -a -ldflags "-X 'github.com/CameronXie/access-control-explorer/internal/version.Version=$$CURRENT_VERSION' -extldflags '-s -w -static'" \ - -tags $* \ - ./cmd - -.PHONY: cleanup-build -cleanup-build: - @rm -rf ${BUILD_DIR} - @mkdir -p ${BUILD_DIR} - ## Action .PHONY: lint-actions lint-actions: diff --git a/README.md b/README.md index 74bc879..9d79c03 100644 --- a/README.md +++ b/README.md @@ -2,35 +2,126 @@ [![Test](https://github.com/CameronXie/access-control-explorer/actions/workflows/test.yaml/badge.svg)](https://github.com/CameronXie/access-control-explorer/actions/workflows/test.yaml) -AccessControlExplorer is a project designed to facilitate the testing and exploration of various access control -architectures and models. The objective is to perform a thorough evaluation of different access control mechanisms in -terms of their effectiveness, performance, and adaptability. +## Purpose -## Features +Access Control Explorer is designed to facilitate the exploration and implementation of modern access control +architectures. The project provides reusable libraries and practical examples to evaluate different access control +mechanisms in terms of their effectiveness, performance, and adaptability to real-world scenarios. -| Access Control Model | Description | Build Command | -|----------------------------------|--------------------------|-------------------| -| Role-Based Access Control (RBAC) | implemented using Casbin | `make api-casbin` | -| Role-Based Access Control (RBAC) | implemented using OPA | `make api-opa` | +The primary objective is to offer developers and security practitioners a comprehensive toolkit for understanding and +implementing sophisticated access control patterns, with emphasis on attribute-based access control (ABAC) and its +practical applications. -## API Endpoints +## Components -The API endpoints for the access control models are registered in the [`internal/api/api.go`](internal/api/rest/api.go) -file. +### ABAC Library -## Getting Started +The [`abac/`](abac/) directory contains a general-purpose ABAC library following XACML-style architecture: -1. Clone the repository and change into the directory. -2. This project uses Docker for the local development environment. To start the Docker container and generate the RSA - key pair for JWTs, run `make up`. This command also sets the necessary environment variables, such as - `PRIVATE_KEY_BASE64` and `PUBLIC_KEY_BASE64`, for the RSA keys. -3. Inside the Docker container, use the build commands from the [features](#features) table to compile the specific API. - For example: `make api-opa`. -4. Inside the Docker container, run the compiled API located in the `_dist/build` directory. For example: - `_dist/build/api-opa`. -5. Access the API via [http://localhost:8080](http://localhost:8080). +- **Decision Maker (Policy Decision Point)**: Policy decision maker with configurable policy resolvers +- **Policy Provider (Policy Retrieval Point)**: Policy provider with file-based storage support +- **Enforcer (Policy Enforcement Point)**: Enforcement interfaces and implementations +- **Request Orchestrator (Context Handler)**: Request orchestrator for enriching access requests with contextual attributes +- **Info Provider (Policy Information Point)**: Information provider for enriching requests with additional contextual data +- **Policy Evaluator**: Policy evaluation engine with OPA/Rego implementation for policy execution +- **Extensions**: Support for obligations, advices, and custom information providers -## Test +The library provides clean interfaces that can be extended with custom implementations for different deployment +scenarios and policy requirements. -To run the tests, simply execute the following command `make test`. This command will perform GitHub Actions linting, Go -code linting, and Go unit tests to ensure the code quality and functionality. +### Examples + +#### REST API with ABAC Enforcement + +The [`examples/abac/`](examples/abac/) directory demonstrates a complete implementation of ABAC enforcement in a REST +API context: + +- **E-commerce Use Case**: Order management system with role-based permissions implemented through ABAC +- **HTTP Middleware**: Enforcer (Policy Enforcement Point) as HTTP middleware +- **JWT Authentication**: Token-based authentication with RS256 signing and automatic user context enrichment +- **Policy Implementation**: Rego policies implementing RBAC patterns within ABAC framework +- **Obligations and Advices**: Practical examples of audit logging and caching hints + +For detailed setup and usage instructions, see the [ABAC Example README](examples/abac/README.md). + +## Development Setup + +This project uses Docker and Docker Compose for local development environment setup. + +### Prerequisites + +- Docker and Docker Compose +- Make + +### Local Environment + +Create and start the development environment: + +```shell +make up +``` + +This command: + +- Generates RSA key pairs for JWT signing/verification if not present +- Creates the necessary `.env` file from `.env.example` +- Starts all required services via Docker Compose + +Stop the development environment: + +```shell +make down +``` + +### Testing + +Run the complete test suite: + +```shell +make test +``` + +This includes: + +- GitHub Actions linting +- Go code linting and formatting +- Go unit tests with race detection and coverage analysis + +For Go-specific tests only: + +```shell +make test-go +``` + +Test all examples in the project: + +```shell +make test-examples +``` + +This command iterates through all example directories and runs their individual test suites, ensuring that all practical +implementations work correctly with the core ABAC library. + +Test artifacts are generated in `_dist/tests/` including coverage reports. + +### Code Quality + +Lint Go code: + +```shell +make lint-go +``` + +Lint GitHub Actions workflows: + +```shell +make lint-actions +``` + +## Contributing + +1. Ensure Docker and Make are installed +2. Run `make up` to set up the development environment +3. Make your changes +4. Run `make test` to verify all tests pass +5. Submit your pull request diff --git a/abac/decisionmaker/decisionmaker.go b/abac/decisionmaker/decisionmaker.go new file mode 100644 index 0000000..d2cbc81 --- /dev/null +++ b/abac/decisionmaker/decisionmaker.go @@ -0,0 +1,381 @@ +package decisionmaker + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/CameronXie/access-control-explorer/abac/policyprovider" + "github.com/google/uuid" + "golang.org/x/sync/errgroup" +) + +// Decision represents the possible outcomes of an authorization decision +type Decision string + +const ( + Permit Decision = "Permit" // Request is allowed + Deny Decision = "Deny" // Request is denied + Indeterminate Decision = "Indeterminate" // Errors prevented making a decision + NotApplicable Decision = "NotApplicable" // No applicable policy was found +) + +// UnmarshalJSON parses the JSON-encoded data and validates it as one of the defined Decision values. +func (d *Decision) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + switch Decision(s) { + case Permit, Deny, Indeterminate, NotApplicable: + *d = Decision(s) + return nil + default: + return fmt.Errorf("invalid decision value: %q, must be one of: Permit, Deny, Indeterminate, NotApplicable", s) + } +} + +// StatusCode represents the possible states of the decision +type StatusCode string + +const ( + StatusOK StatusCode = "OK" // Decision was successfully evaluated + StatusMissingAttribute StatusCode = "AttributeMissing" // A required attribute is missing + StatusProcessingError StatusCode = "ProcessingError" // An internal processing error occurred + StatusInvalidRequest StatusCode = "InvalidRequest" // The request is malformed + StatusPolicyNotFound StatusCode = "PolicyNotFound" // No matching policies were found + StatusEvaluationError StatusCode = "EvaluationError" // General evaluation error +) + +// UnmarshalJSON parses a JSON-encoded byte array and sets the StatusCode value if it matches a valid predefined status. +// Returns an error if the provided JSON does not represent a valid StatusCode. +func (sc *StatusCode) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + switch StatusCode(s) { + case StatusOK, StatusMissingAttribute, StatusProcessingError, StatusInvalidRequest, StatusPolicyNotFound, StatusEvaluationError: + *sc = StatusCode(s) + return nil + default: + return fmt.Errorf( + "invalid status value: %q, must be one of: %s", + s, + strings.Join([]string{ + string(StatusOK), + string(StatusMissingAttribute), + string(StatusProcessingError), + string(StatusInvalidRequest), + string(StatusPolicyNotFound), + string(StatusEvaluationError), + }, ", "), + ) + } +} + +// Subject represents the entity requesting access (user, service, etc.) +type Subject struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +// Resource represents the protected asset being accessed +type Resource struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +// Action represents the operation being performed on the resource +type Action struct { + ID string `json:"id"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +// DecisionRequest represents an access decision request including the subject, resource, action, and environmental context. +type DecisionRequest struct { + RequestID uuid.UUID `json:"requestId"` + Subject Subject `json:"subject"` + Resource Resource `json:"resource"` + Action Action `json:"action"` + Environment map[string]any `json:"environment,omitempty"` +} + +// Obligation represents a mandatory action that must be performed when enforcing the decision +type Obligation struct { + ID string `json:"id"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +// Advice represents a recommended but not mandatory action related to the decision +type Advice struct { + ID string `json:"id"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +// Status provides detailed information about the outcome of the decision process +type Status struct { + Code StatusCode `json:"code"` + Message string `json:"message"` +} + +// DecisionResponse represents the result of evaluating an authorization request, including decisions, status, and obligations. +type DecisionResponse struct { + RequestID uuid.UUID `json:"requestId"` + Decision Decision `json:"decision"` + Status *Status `json:"status,omitempty"` + Obligations []Obligation `json:"obligations,omitempty"` + Advice []Advice `json:"advice,omitempty"` + EvaluatedAt time.Time `json:"evaluatedAt"` + PolicyIdReferences []PolicyIdReference `json:"policyIdReferences"` +} + +// Policy represents a retrieved policy that will be evaluated against a request +type Policy struct { + ID string + Version string + Content []byte +} + +// PolicyIdReference represents a reference to a policy, including its unique identifier and version information. +type PolicyIdReference struct { + ID string `json:"id"` + Version string `json:"version"` +} + +// PolicyResolver defines the interface for components that resolve policy references based on a decision request. +type PolicyResolver interface { + // Resolve analyzes a decision request and returns policy references that are applicable to the request. + Resolve(ctx context.Context, req *DecisionRequest) ([]PolicyIdReference, error) +} + +// DecisionMaker defines the interface for components that make authorization decisions +type DecisionMaker interface { + // MakeDecision evaluates a decision request using policies and returns an authorization decision or an error. + MakeDecision(ctx context.Context, req *DecisionRequest) (*DecisionResponse, error) +} + +// decisionMaker implements the DecisionMaker interface +type decisionMaker struct { + processors []PolicyResolver + provider policyprovider.PolicyProvider + evaluator PolicyEvaluator +} + +// Option defines configuration options for DecisionMaker +type Option func(*decisionMaker) + +// NewDecisionMaker creates a new DecisionMaker with the provided dependencies and options +func NewDecisionMaker(provider policyprovider.PolicyProvider, evaluator PolicyEvaluator, options ...Option) DecisionMaker { + dm := &decisionMaker{ + processors: make([]PolicyResolver, 0), + provider: provider, + evaluator: evaluator, + } + + for _, option := range options { + option(dm) + } + + return dm +} + +// WithPolicyResolver registers a policy resolver +func WithPolicyResolver(processor PolicyResolver) Option { + return func(dm *decisionMaker) { + dm.processors = append(dm.processors, processor) + } +} + +// MakeDecision evaluates the given decision request based on applicable policies and returns a decision response or an error. +func (d *decisionMaker) MakeDecision(ctx context.Context, req *DecisionRequest) (*DecisionResponse, error) { + if req == nil { + return nil, errors.New("decision request cannot be nil") + } + + // Resolve applicable policy references for this request + policyRefs, err := d.resolve(ctx, req) + if err != nil { + return &DecisionResponse{ + RequestID: req.RequestID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: fmt.Sprintf("Failed to resolve policies: %v", err), + }, + EvaluatedAt: time.Now(), + }, nil + } + + if len(policyRefs) == 0 { + return &DecisionResponse{ + RequestID: req.RequestID, + Decision: NotApplicable, + Status: &Status{ + Code: StatusPolicyNotFound, + Message: "No applicable policies found for the request", + }, + EvaluatedAt: time.Now(), + }, nil + } + + // Retrieve policy contents + policies, err := d.getPolicies(ctx, policyRefs) + if err != nil { + return &DecisionResponse{ + RequestID: req.RequestID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: fmt.Sprintf("Failed to retrieve policies: %v", err), + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: policyRefs, + }, nil + } + + // Evaluate the request against policies + result, err := d.evaluator.Evaluate(ctx, req, policies) + if err != nil { + return &DecisionResponse{ + RequestID: req.RequestID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusEvaluationError, + Message: fmt.Sprintf("Policy evaluation failed: %v", err), + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: policyRefs, + }, nil + } + + // Build the final decision response with evaluation results, request metadata, and policy information + response := &DecisionResponse{ + RequestID: req.RequestID, + Decision: result.Decision, + Status: &result.Status, + Obligations: result.Obligations, + Advice: result.Advice, + EvaluatedAt: time.Now(), + PolicyIdReferences: policyRefs, + } + + return response, nil +} + +// resolve executes the resolution process for a decision request using configured processors and returns unique policy references. +func (d *decisionMaker) resolve(ctx context.Context, req *DecisionRequest) ([]PolicyIdReference, error) { + if len(d.processors) == 0 { + return nil, errors.New("no policy resolve processors configured") + } + + // Create an error group to manage parallel execution + g, ctx := errgroup.WithContext(ctx) + + // Use mutex to protect the shared map + var mu sync.Mutex + seen := make(map[string]PolicyIdReference) + + // Launch each processor in its own goroutine + for _, processor := range d.processors { + proc := processor + g.Go(func() error { + // Check if context was canceled before processing + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Resolve the request + results, err := proc.Resolve(ctx, req) + if err != nil { + return err + } + + // Skip lock if no results + if len(results) == 0 { + return nil + } + + // Safely add results to the shared map using policy ID as the key + mu.Lock() + defer mu.Unlock() + for _, policyRef := range results { + existingPolicyRef, exists := seen[policyRef.ID] + + if !exists { + seen[policyRef.ID] = policyRef + continue + } + + if existingPolicyRef.Version == policyRef.Version { + return fmt.Errorf( + "duplicate policy reference detected: policy '%s' version '%s' returned by multiple processors", + existingPolicyRef.ID, + existingPolicyRef.Version, + ) + } + + return fmt.Errorf("duplicate policy ID '%s' found: existing version '%s', conflicting version '%s'", + policyRef.ID, existingPolicyRef.Version, policyRef.Version) + } + + return nil + }) + } + + // Wait for all processors to complete or first error + if err := g.Wait(); err != nil { + return nil, err + } + + // Convert the map values to a slice + policyRefs := make([]PolicyIdReference, 0, len(seen)) + for _, policyRef := range seen { + policyRefs = append(policyRefs, policyRef) + } + + return policyRefs, nil +} + +// getPolicies retrieves policy content for a list of policy references +func (d *decisionMaker) getPolicies(ctx context.Context, policyRefs []PolicyIdReference) ([]Policy, error) { + if len(policyRefs) == 0 { + return nil, errors.New("no policy references provided") + } + + // Convert PolicyReference to PolicyRequests + policyRequests := make([]policyprovider.GetPolicyRequest, 0, len(policyRefs)) + for _, ref := range policyRefs { + policyRequests = append(policyRequests, policyprovider.GetPolicyRequest{ + ID: ref.ID, + Version: ref.Version, + }) + } + + // Request policies from provider + responses, err := d.provider.GetPolicies(ctx, policyRequests) + if err != nil { + return nil, fmt.Errorf("failed to retrieve policies: %w", err) + } + + // Resolve responses + policies := make([]Policy, 0, len(responses)) + for _, resp := range responses { + policies = append(policies, Policy{ + ID: resp.ID, + Version: resp.Version, + Content: resp.Content, + }) + } + + return policies, nil +} diff --git a/abac/decisionmaker/decisionmaker_test.go b/abac/decisionmaker/decisionmaker_test.go new file mode 100644 index 0000000..c48d86b --- /dev/null +++ b/abac/decisionmaker/decisionmaker_test.go @@ -0,0 +1,734 @@ +//nolint:lll // unit tests +package decisionmaker + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/CameronXie/access-control-explorer/abac/policyprovider" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// Mock implementations +type mockPolicyProvider struct { + mock.Mock +} + +func (m *mockPolicyProvider) GetPolicies( + ctx context.Context, + reqs []policyprovider.GetPolicyRequest, +) ([]policyprovider.PolicyResponse, error) { + args := m.Called(ctx, reqs) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]policyprovider.PolicyResponse), args.Error(1) +} + +type mockPolicyEvaluator struct { + mock.Mock +} + +func (m *mockPolicyEvaluator) Evaluate(ctx context.Context, req *DecisionRequest, policies []Policy) (*EvaluationResult, error) { + args := m.Called(ctx, req, policies) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*EvaluationResult), args.Error(1) +} + +type mockPolicyResolver struct { + mock.Mock + delay time.Duration +} + +func (m *mockPolicyResolver) Resolve(ctx context.Context, req *DecisionRequest) ([]PolicyIdReference, error) { + // Simulate processing delay if configured + if m.delay > 0 { + select { + case <-time.After(m.delay): + // Continue with processing after delay + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + args := m.Called(ctx, req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]PolicyIdReference), args.Error(1) +} + +type mockPolicyResolverConfig struct { + delay time.Duration + policyIdRefs []PolicyIdReference + err error +} + +// TestDecisionMaker_MakeDecision tests the DecisionMaker's MakeDecision method +func TestDecisionMaker_MakeDecision(t *testing.T) { + // Common test variables + fixedUUID := uuid.New() + + // Standard request used across tests + standardRequest := &DecisionRequest{ + RequestID: fixedUUID, + Subject: Subject{ + ID: "user123", + Type: "user", + Attributes: map[string]any{ + "roles": []string{"admin", "user"}, + }, + }, + Resource: Resource{ + ID: "resource456", + Type: "document", + }, + Action: Action{ + ID: "read", + }, + } + + // Mock policy responses + policyResponses := []policyprovider.PolicyResponse{ + { + ID: "policy1", + Version: "1.0", + Content: []byte(`{"policy": "content1"}`), + }, + { + ID: "policy2", + Version: "1.0", + Content: []byte(`{"policy": "content2"}`), + }, + } + + // Define test cases + tests := map[string]struct { + request *DecisionRequest + policyResolverConfigs []*mockPolicyResolverConfig + policyResolveErr bool + policyProviderResponses []policyprovider.PolicyResponse + policyRetrievalFail error + evaluationResult *EvaluationResult + evaluatorError error + expectedResponse *DecisionResponse + expectedError string + }{ + "should return error when request is nil": { + request: nil, + expectedError: "decision request cannot be nil", + }, + + "should return indeterminate decision when no resolvers configured": { + request: standardRequest, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: "Failed to resolve policies: no policy resolve processors configured", + }, + EvaluatedAt: time.Now(), + }, + }, + + "should return indeterminate decision when resolver fails": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + { + policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}, + err: errors.New("resolver error"), + }, + }, + policyResolveErr: true, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: "Failed to resolve policies: resolver error", + }, + EvaluatedAt: time.Now(), + }, + }, + + "should return not applicable decision when no applicable policies found": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{}}, + }, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: NotApplicable, + Status: &Status{ + Code: StatusPolicyNotFound, + Message: "No applicable policies found for the request", + }, + EvaluatedAt: time.Now(), + }, + }, + + "should return indeterminate decision when policy retrieval fails": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}}, + {policyIdRefs: []PolicyIdReference{{ID: "policy2", Version: "1.0"}}}, + }, + policyRetrievalFail: errors.New("provider error"), + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: "Failed to retrieve policies: failed to retrieve policies: provider error", + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + {ID: "policy2", Version: "1.0"}, + }, + }, + }, + + "should return indeterminate decision when policy evaluation fails": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}}, + {policyIdRefs: []PolicyIdReference{{ID: "policy2", Version: "1.0"}}}, + }, + policyProviderResponses: policyResponses, + evaluatorError: errors.New("evaluator error"), + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusEvaluationError, + Message: "Policy evaluation failed: evaluator error", + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + {ID: "policy2", Version: "1.0"}, + }, + }, + }, + + "should return permit decision with obligations and advice": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}, {ID: "policy2", Version: "1.0"}}}, + }, + policyProviderResponses: policyResponses, + evaluationResult: &EvaluationResult{ + Decision: Permit, + Status: Status{ + Code: StatusOK, + Message: "Access permitted", + }, + Obligations: []Obligation{ + { + ID: "log-access", + Attributes: map[string]any{ + "level": "info", + }, + }, + }, + Advice: []Advice{ + { + ID: "remind-confidentiality", + Attributes: map[string]any{ + "message": "This document is confidential", + }, + }, + }, + }, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Permit, + Status: &Status{ + Code: StatusOK, + Message: "Access permitted", + }, + Obligations: []Obligation{ + { + ID: "log-access", + Attributes: map[string]any{ + "level": "info", + }, + }, + }, + Advice: []Advice{ + { + ID: "remind-confidentiality", + Attributes: map[string]any{ + "message": "This document is confidential", + }, + }, + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + {ID: "policy2", Version: "1.0"}, + }, + }, + }, + + "should return deny decision when policy evaluation denies access": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}}, + }, + policyProviderResponses: []policyprovider.PolicyResponse{policyResponses[0]}, + evaluationResult: &EvaluationResult{ + Decision: Deny, + Status: Status{ + Code: StatusOK, + Message: "Access denied", + }, + }, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Deny, + Status: &Status{ + Code: StatusOK, + Message: "Access denied", + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + }, + }, + }, + + "should handle multiple resolvers returning unique policies": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}, {ID: "policy2", Version: "1.0"}}}, + }, + policyProviderResponses: policyResponses, + evaluationResult: &EvaluationResult{ + Decision: Permit, + Status: Status{ + Code: StatusOK, + Message: "Access permitted", + }, + }, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Permit, + Status: &Status{ + Code: StatusOK, + Message: "Access permitted", + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + {ID: "policy2", Version: "1.0"}, + }, + }, + }, + + "should return error when resolvers return duplicate policy with same version": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}}, + {policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}}, + }, + policyResolveErr: true, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: "Failed to resolve policies: duplicate policy reference detected: policy 'policy1' version '1.0' returned by multiple processors", + }, + EvaluatedAt: time.Now(), + }, + }, + + "should return error when resolvers return duplicate policy with different versions": { + request: standardRequest, + policyResolverConfigs: []*mockPolicyResolverConfig{ + {delay: 0, policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "1.0"}}}, + {delay: 50 * time.Millisecond, policyIdRefs: []PolicyIdReference{{ID: "policy1", Version: "2.0"}}}, + }, + policyResolveErr: true, + expectedResponse: &DecisionResponse{ + RequestID: fixedUUID, + Decision: Indeterminate, + Status: &Status{ + Code: StatusProcessingError, + Message: "Failed to resolve policies: duplicate policy ID 'policy1' found: existing version '1.0', conflicting version '2.0'", + }, + EvaluatedAt: time.Now(), + }, + }, + } + + // Run tests + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + // Create mocks + mockProvider := new(mockPolicyProvider) + mockEvaluator := new(mockPolicyEvaluator) + resolvers := make([]*mockPolicyResolver, 0) + + // Create processors and configure mocks + options := make([]Option, 0) + var allPolicyIdRefs []PolicyIdReference + + if tc.request != nil { + for _, config := range tc.policyResolverConfigs { + resolver := &mockPolicyResolver{delay: config.delay} + + resolver.On("Resolve", mock.Anything, tc.request).Return(config.policyIdRefs, config.err) + resolvers = append(resolvers, resolver) + allPolicyIdRefs = append(allPolicyIdRefs, config.policyIdRefs...) + options = append(options, WithPolicyResolver(resolver)) + } + + if len(allPolicyIdRefs) > 0 && tc.policyResolveErr == false { + mockProvider.On( + "GetPolicies", + mock.Anything, + mock.Anything, + ).Return(tc.policyProviderResponses, tc.policyRetrievalFail) + + if tc.policyRetrievalFail == nil { + mockEvaluator.On( + "Evaluate", + mock.Anything, + tc.request, + mock.Anything, + ).Return(tc.evaluationResult, tc.evaluatorError) + } + } + } + + // Create the decision maker with processors + dm := NewDecisionMaker(mockProvider, mockEvaluator, options...) + + // Execute the method under test + response, err := dm.MakeDecision(context.Background(), tc.request) + + // Verify error + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + // Verify response using our custom assertion + assertResponseMatch(t, tc.expectedResponse, response) + assert.NoError(t, err) + + // Verify that all expected mock calls were made + mockProvider.AssertExpectations(t) + mockEvaluator.AssertExpectations(t) + for _, resolver := range resolvers { + resolver.AssertExpectations(t) + } + }) + } +} + +// TestDecisionMaker_GetPolicies tests the getPolicies helper method +func TestDecisionMaker_GetPolicies(t *testing.T) { + tests := map[string]struct { + policyRefs []PolicyIdReference + providerResponses []policyprovider.PolicyResponse + providerError error + expectedPolicies []Policy + expectedError string + }{ + "should return error when empty policy references list provided": { + policyRefs: []PolicyIdReference{}, + expectedError: "no policy references provided", + }, + + "should propagate policy provider error": { + policyRefs: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + {ID: "policy2", Version: "1.0"}, + }, + providerError: errors.New("provider error"), + expectedError: "failed to retrieve policies: provider error", + }, + + "should successfully convert provider responses to policies": { + policyRefs: []PolicyIdReference{ + {ID: "policy1", Version: "1.0"}, + {ID: "policy2", Version: "1.0"}, + }, + providerResponses: []policyprovider.PolicyResponse{ + { + ID: "policy1", + Version: "1.0", + Content: []byte(`{"policy": "content1"}`), + }, + { + ID: "policy2", + Version: "1.0", + Content: []byte(`{"policy": "content2"}`), + }, + }, + expectedPolicies: []Policy{ + { + ID: "policy1", + Version: "1.0", + Content: []byte(`{"policy": "content1"}`), + }, + { + ID: "policy2", + Version: "1.0", + Content: []byte(`{"policy": "content2"}`), + }, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + // Create mock provider + mockProvider := new(mockPolicyProvider) + mockEvaluator := new(mockPolicyEvaluator) + + // Configure provider mock if needed + if len(tc.policyRefs) > 0 { + var policyRequests []policyprovider.GetPolicyRequest + for _, ref := range tc.policyRefs { + policyRequests = append(policyRequests, policyprovider.GetPolicyRequest{ + ID: ref.ID, + Version: ref.Version, + }) + } + mockProvider.On("GetPolicies", mock.Anything, policyRequests). + Return(tc.providerResponses, tc.providerError) + } + + // Create decision maker + dm := NewDecisionMaker(mockProvider, mockEvaluator).(*decisionMaker) + + // Call the private method via type assertion + policies, err := dm.getPolicies(context.Background(), tc.policyRefs) + + // Verify results + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.expectedPolicies, policies) + + // Verify mock expectations + mockProvider.AssertExpectations(t) + }) + } +} + +// Custom assertion helper to compare DecisionResponse objects with time tolerance +func assertResponseMatch(t *testing.T, expected, actual *DecisionResponse) { + require.NotNil(t, actual, "actual response should not be nil") + require.NotNil(t, expected, "expected response should not be nil") + + assert.Equal(t, expected.RequestID, actual.RequestID) + assert.Equal(t, expected.Decision, actual.Decision) + + // Compare status + if expected.Status != nil { + require.NotNil(t, actual.Status) + assert.Equal(t, expected.Status.Code, actual.Status.Code) + assert.Equal(t, expected.Status.Message, actual.Status.Message) + } else { + assert.Nil(t, actual.Status) + } + + // Compare obligations and advice + assert.Equal(t, expected.Obligations, actual.Obligations) + assert.Equal(t, expected.Advice, actual.Advice) + + // Check policy references (order might vary, so check length and contents) + assert.Len(t, actual.PolicyIdReferences, len(expected.PolicyIdReferences)) + for _, expectedPolicy := range expected.PolicyIdReferences { + found := false + for _, actualPolicy := range actual.PolicyIdReferences { + if expectedPolicy.ID == actualPolicy.ID && expectedPolicy.Version == actualPolicy.Version { + found = true + break + } + } + assert.True(t, found, "Expected policy reference %+v not found in actual response", expectedPolicy) + } + + // Check that EvaluatedAt is recent (within last 5 seconds) + assert.WithinDuration(t, time.Now(), actual.EvaluatedAt, 5*time.Second) +} + +// TestDecision_UnmarshalJSON tests the Decision type's UnmarshalJSON method +func TestDecision_UnmarshalJSON(t *testing.T) { + tests := map[string]struct { + input string + expected Decision + expectedError string + }{ + "should unmarshal valid Permit decision": { + input: `"Permit"`, + expected: Permit, + }, + "should unmarshal valid Deny decision": { + input: `"Deny"`, + expected: Deny, + }, + "should unmarshal valid Indeterminate decision": { + input: `"Indeterminate"`, + expected: Indeterminate, + }, + "should unmarshal valid NotApplicable decision": { + input: `"NotApplicable"`, + expected: NotApplicable, + }, + "should return error for invalid decision value": { + input: `"Invalid"`, + expectedError: `invalid decision value: "Invalid", must be one of: Permit, Deny, Indeterminate, NotApplicable`, + }, + "should return error for empty string": { + input: `""`, + expectedError: "invalid decision value: \"\", must be one of: Permit, Deny, Indeterminate, NotApplicable", + }, + "should return error for lowercase decision": { + input: `"permit"`, + expectedError: `invalid decision value: "permit", must be one of: Permit, Deny, Indeterminate, NotApplicable`, + }, + "should return error for mixed case decision": { + input: `"PERMIT"`, + expectedError: `invalid decision value: "PERMIT", must be one of: Permit, Deny, Indeterminate, NotApplicable`, + }, + "should return error for numeric input": { + input: `123`, + expectedError: "json: cannot unmarshal number into Go value of type string", + }, + "should return error for boolean input": { + input: `true`, + expectedError: "json: cannot unmarshal bool into Go value of type string", + }, + "should return error for object input": { + input: `{"decision": "Permit"}`, + expectedError: "json: cannot unmarshal object into Go value of type string", + }, + "should return error for array input": { + input: `["Permit"]`, + expectedError: "json: cannot unmarshal array into Go value of type string", + }, + "should return error for null input": { + input: `null`, + expectedError: "invalid decision value: \"\", must be one of: Permit, Deny, Indeterminate, NotApplicable", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + var decision Decision + err := json.Unmarshal([]byte(tc.input), &decision) + + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.expected, decision) + }) + } +} + +// TestStatusCode_UnmarshalJSON tests the StatusCode type's UnmarshalJSON method +func TestStatusCode_UnmarshalJSON(t *testing.T) { + tests := map[string]struct { + input string + expected StatusCode + expectedError string + }{ + "should unmarshal valid StatusOK": { + input: `"OK"`, + expected: StatusOK, + }, + "should unmarshal valid StatusMissingAttribute": { + input: `"AttributeMissing"`, + expected: StatusMissingAttribute, + }, + "should unmarshal valid StatusProcessingError": { + input: `"ProcessingError"`, + expected: StatusProcessingError, + }, + "should unmarshal valid StatusEvaluationError": { + input: `"EvaluationError"`, + expected: StatusEvaluationError, + }, + "should unmarshal valid StatusPolicyNotFound": { + input: `"PolicyNotFound"`, + expected: StatusPolicyNotFound, + }, + "should return error for invalid status code": { + input: `"INVALID_STATUS"`, + expectedError: "invalid status value: \"INVALID_STATUS\", must be one of: OK, AttributeMissing, ProcessingError, InvalidRequest, PolicyNotFound, EvaluationError", + }, + "should return error for empty string": { + input: `""`, + expectedError: "invalid status value: \"\", must be one of: OK, AttributeMissing, ProcessingError, InvalidRequest, PolicyNotFound, EvaluationError", + }, + "should return error for lowercase status": { + input: `"ok"`, + expectedError: "invalid status value: \"ok\", must be one of: OK, AttributeMissing, ProcessingError, InvalidRequest, PolicyNotFound, EvaluationError", + }, + "should return error for mixed case status": { + input: `"Ok"`, + expectedError: "invalid status value: \"Ok\", must be one of: OK, AttributeMissing, ProcessingError, InvalidRequest, PolicyNotFound, EvaluationError", + }, + "should return error for numeric input": { + input: `200`, + expectedError: "json: cannot unmarshal number into Go value of type string", + }, + "should return error for boolean input": { + input: `false`, + expectedError: "json: cannot unmarshal bool into Go value of type string", + }, + "should return error for object input": { + input: `{"status": "OK"}`, + expectedError: "json: cannot unmarshal object into Go value of type string", + }, + "should return error for array input": { + input: `["OK"]`, + expectedError: "json: cannot unmarshal array into Go value of type string", + }, + "should return error for null input": { + input: `null`, + expectedError: "invalid status value: \"\", must be one of: OK, AttributeMissing, ProcessingError, InvalidRequest, PolicyNotFound, EvaluationError", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + var statusCode StatusCode + err := json.Unmarshal([]byte(tc.input), &statusCode) + + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, statusCode) + } + }) + } +} diff --git a/abac/decisionmaker/policyevaluator.go b/abac/decisionmaker/policyevaluator.go new file mode 100644 index 0000000..34157c8 --- /dev/null +++ b/abac/decisionmaker/policyevaluator.go @@ -0,0 +1,19 @@ +package decisionmaker + +import ( + "context" +) + +// EvaluationResult contains the output from the policy evaluation process +type EvaluationResult struct { + Decision Decision `json:"decision"` + Status Status `json:"status"` + Obligations []Obligation `json:"obligations,omitempty"` + Advice []Advice `json:"advice,omitempty"` +} + +// PolicyEvaluator defines the interface for components that evaluate policies against decision requests to produce authorization decisions +type PolicyEvaluator interface { + // Evaluate evaluates a decision request against a set of policies and returns an evaluation result or an error. + Evaluate(ctx context.Context, req *DecisionRequest, policies []Policy) (*EvaluationResult, error) +} diff --git a/abac/decisionmaker/policyevaluator/opa/evaluator.go b/abac/decisionmaker/policyevaluator/opa/evaluator.go new file mode 100644 index 0000000..7891a39 --- /dev/null +++ b/abac/decisionmaker/policyevaluator/opa/evaluator.go @@ -0,0 +1,79 @@ +package opa + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" + "github.com/open-policy-agent/opa/v1/rego" +) + +// evaluator implements PolicyEvaluator using Open Policy Agent (OPA) Rego +type evaluator struct { + query string +} + +// NewEvaluator creates a PolicyEvaluator instance with the specified Rego query +func NewEvaluator(query string) decisionmaker.PolicyEvaluator { + return &evaluator{ + query: query, + } +} + +// Evaluate executes policies against a decision request using OPA Rego engine +func (e *evaluator) Evaluate( + ctx context.Context, + req *decisionmaker.DecisionRequest, + policies []decisionmaker.Policy, +) (*decisionmaker.EvaluationResult, error) { + if req == nil { + return nil, errors.New("decision request cannot be nil") + } + + if len(policies) == 0 { + return nil, errors.New("no policies provided for evaluation") + } + + // Build Rego configuration + regoArgs := []func(*rego.Rego){ + rego.Query(e.query), + rego.Input(req), + } + + // Add policies as Rego modules + for _, policy := range policies { + moduleName := fmt.Sprintf("policy_%s", policy.ID) + regoArgs = append(regoArgs, rego.Module(moduleName, string(policy.Content))) + } + + // Execute policy evaluation + instance := rego.New(regoArgs...) + resultSet, err := instance.Eval(ctx) + if err != nil { + return nil, fmt.Errorf("policy evaluation failed: %w", err) + } + + if len(resultSet) == 0 || len(resultSet[0].Expressions) == 0 { + return nil, errors.New("no evaluation results returned from policy engine") + } + + // Convert result to EvaluationResult + return convertResult(resultSet[0].Expressions[0].Value) +} + +// convertResult transforms OPA evaluation output to EvaluationResult struct +func convertResult(value any) (*decisionmaker.EvaluationResult, error) { + resultBytes, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("failed to marshal evaluation result: %w", err) + } + + var result decisionmaker.EvaluationResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal evaluation result: %w", err) + } + + return &result, nil +} diff --git a/abac/decisionmaker/policyevaluator/opa/evaluator_test.go b/abac/decisionmaker/policyevaluator/opa/evaluator_test.go new file mode 100644 index 0000000..2686045 --- /dev/null +++ b/abac/decisionmaker/policyevaluator/opa/evaluator_test.go @@ -0,0 +1,290 @@ +package opa + +import ( + "context" + "testing" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" + "github.com/stretchr/testify/assert" +) + +func TestEvaluator_Evaluate(t *testing.T) { + tests := map[string]struct { + query string + request *decisionmaker.DecisionRequest + policies []decisionmaker.Policy + expectedResult *decisionmaker.EvaluationResult + expectedError string + }{ + "nil request should return error": { + query: "data.abac.result", + request: nil, + policies: []decisionmaker.Policy{ + getSubjectPolicy(), + }, + expectedError: "decision request cannot be nil", + }, + + "empty policies should return error": { + query: "data.abac.result", + request: newTestRequest([]string{"admin"}, "read"), + policies: []decisionmaker.Policy{}, + expectedError: "no policies provided for evaluation", + }, + + "invalid policy should return error": { + query: "data.abac.result", + request: newTestRequest([]string{"admin"}, "read"), + policies: []decisionmaker.Policy{{ + ID: "invalid", + Content: []byte("package"), + }}, + expectedError: `policy evaluation failed: 1 error occurred: policy_invalid:1: rego_parse_error: unexpected eof token`, + }, + + "policy has no result should return error": { + query: "data.abac.result", + request: newTestRequest([]string{"admin"}, "read"), + policies: []decisionmaker.Policy{{ + ID: "invalid", + Content: []byte("package abac"), + }}, + expectedError: "no evaluation results returned from policy engine", + }, + + "admin user should get permit decision with obligations": { + query: "data.abac.result", + request: newTestRequest([]string{"admin"}, "read"), + policies: []decisionmaker.Policy{getSubjectPolicy()}, + expectedResult: &decisionmaker.EvaluationResult{ + Decision: decisionmaker.Permit, + Status: decisionmaker.Status{ + Code: "OK", + Message: "Access granted - Administrative privileges verified for user", + }, + Obligations: []decisionmaker.Obligation{ + { + ID: "audit_logging", + Attributes: map[string]any{ + "level": "INFO", + "message": "Administrative access granted to user with verified admin role", + }, + }, + }, + }, + }, + + "customer with update action should get deny": { + query: "data.abac.result", + request: newTestRequest([]string{"customer"}, "update"), + policies: []decisionmaker.Policy{getSubjectPolicy(), getResourcePolicy()}, + expectedResult: &decisionmaker.EvaluationResult{ + Decision: decisionmaker.Deny, + Status: decisionmaker.Status{ + Code: "OK", + Message: "Access denied - Customers are not authorized to update product information", + }, + Obligations: []decisionmaker.Obligation{ + { + ID: "audit_logging", + Attributes: map[string]any{ + "level": "WARN", + "message": "Customer attempted unauthorized product update operation", + }, + }, + }, + }, + }, + + "customer with create action should get not applicable": { + query: "data.abac.result", + request: newTestRequest([]string{"customer"}, "create"), + policies: []decisionmaker.Policy{getSubjectPolicy()}, + expectedResult: &decisionmaker.EvaluationResult{ + Decision: decisionmaker.NotApplicable, + Status: decisionmaker.Status{ + Code: "PolicyNotFound", + Message: "No applicable access control policy found for the requested resource and action", + }, + Obligations: []decisionmaker.Obligation{ + { + ID: "audit_logging", + Attributes: map[string]any{ + "level": "WARN", + "message": "Access request processed without matching any specific authorization policy", + }, + }, + }, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + result, err := NewEvaluator(tc.query).Evaluate(context.Background(), tc.request, tc.policies) + + // Verify error cases + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + // Verify success cases + assert.NoError(t, err) + assert.EqualValues(t, *tc.expectedResult, *result) + }) + } +} + +func TestEvaluator_ConvertResult(t *testing.T) { + tests := map[string]struct { + input any + expectedError string + expectedResult *decisionmaker.EvaluationResult + }{ + // JSON unmarshaling error test + "invalid input should return error": { + input: `{`, + expectedError: "failed to unmarshal evaluation result: json: cannot unmarshal string", + expectedResult: nil, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + result, err := convertResult(tc.input) + + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + assert.NoError(t, err) + assert.EqualValues(t, *tc.expectedResult, *result) + }) + } +} + +// newTestRequest creates a standard test request with specified roles and action +func newTestRequest(roles []string, actionID string) *decisionmaker.DecisionRequest { + return &decisionmaker.DecisionRequest{ + Subject: decisionmaker.Subject{ + ID: "user123", + Type: "user", + Attributes: map[string]any{ + "roles": roles, + }, + }, + Action: decisionmaker.Action{ + ID: actionID, + }, + Resource: decisionmaker.Resource{ + ID: "product-123", + Type: "product", + Attributes: map[string]any{ + "sku": "123456", + "price": 123.45, + }, + }, + } +} + +// getSubjectPolicy returns a Rego policy that handles subject-based authorization +// Includes admin access rules and default fallback behavior +func getSubjectPolicy() decisionmaker.Policy { + content := ` +package abac + +# Default policy result when no specific rules match +# This ensures every evaluation returns a decision rather than undefined +default result := { + "decision": "NotApplicable", + "status": { + "code": "PolicyNotFound", + "message": "No applicable access control policy found for the requested resource and action", + }, + "obligations": [{ + "id": "audit_logging", + "attributes": { + "level": "WARN", + "message": "Access request processed without matching any specific authorization policy", + }, + }], +} + +# Administrative access rule +# Grants full access to users with admin role in their attributes +# Returns Permit decision with mandatory audit logging obligation +result := { + "decision": "Permit", + "status": { + "code": "OK", + "message": "Access granted - Administrative privileges verified for user", + }, + "obligations": [{ + "id": "audit_logging", + "attributes": { + "level": "INFO", + "message": "Administrative access granted to user with verified admin role" + }, + }], +} if { + user_is_admin +} + +# Helper rule to determine if the requesting user has administrative privileges +# Checks if "admin" role exists in the subject's role attributes +# This separation makes the policy more readable and maintainable +user_is_admin if { + "admin" in input.subject.attributes.roles +}` + return decisionmaker.Policy{ + ID: "subject-policy", + Version: "1.0", + Content: []byte(content), + } +} + +// getResourcePolicy returns a Rego policy that handles resource-specific restrictions +// Implements customer access limitations for product updates +func getResourcePolicy() decisionmaker.Policy { + content := ` +package abac + +# Customer access restriction rule for product updates +# Denies update operations on product resources when requested by customer-only users +# This enforces business rule that customers cannot modify product information +result := { + "decision": "Deny", + "status": { + "code": "OK", + "message": "Access denied - Customers are not authorized to update product information", + }, + "obligations": [{ + "id": "audit_logging", + "attributes": { + "level": "WARN", + "message": "Customer attempted unauthorized product update operation", + }, + }], +} if { + user_is_customer + input.action.id == "update" + input.resource.type == "product" +} + +# Helper rule to identify customer-only users +# This ensures users with customer + other roles (like admin) are not restricted +user_is_customer if { + "customer" in input.subject.attributes.roles + count(input.subject.attributes.roles) == 1 +} +` + return decisionmaker.Policy{ + ID: "resource-policy", + Version: "1.0", + Content: []byte(content), + } +} diff --git a/abac/infoprovider/infoprovider.go b/abac/infoprovider/infoprovider.go new file mode 100644 index 0000000..2d9f121 --- /dev/null +++ b/abac/infoprovider/infoprovider.go @@ -0,0 +1,17 @@ +package infoprovider + +import "context" + +type GetInfoRequest struct { + InfoType string + Params any + Context map[string]string +} + +type GetInfoResponse struct { + Info map[string]any +} + +type InfoProvider interface { + GetInfo(ctx context.Context, req *GetInfoRequest) (*GetInfoResponse, error) +} diff --git a/abac/policyprovider/filestore/policyprovider.go b/abac/policyprovider/filestore/policyprovider.go new file mode 100644 index 0000000..537dee7 --- /dev/null +++ b/abac/policyprovider/filestore/policyprovider.go @@ -0,0 +1,69 @@ +package filestore + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/CameronXie/access-control-explorer/abac/policyprovider" +) + +// policyProvider implements the PolicyProvider interface using the local filesystem +type policyProvider struct { + basePath string // Base directory where policies are stored +} + +// New creates a new filesystem-based PolicyProvider +// The basePath parameter specifies the root directory for policy files +func New(basePath string) policyprovider.PolicyProvider { + return &policyProvider{ + basePath: basePath, + } +} + +// GetPolicies retrieves multiple policies from the filesystem +func (p *policyProvider) GetPolicies(ctx context.Context, reqs []policyprovider.GetPolicyRequest) ([]policyprovider.PolicyResponse, error) { + policies := make([]policyprovider.PolicyResponse, 0, len(reqs)) + + for _, req := range reqs { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + policy, err := p.getPolicy(req) + if err != nil { + return nil, fmt.Errorf("failed to get policy %s@%s: %w", req.ID, req.Version, err) + } + + policies = append(policies, *policy) + } + } + + return policies, nil +} + +// getPolicy retrieves a single policy from the filesystem +func (p *policyProvider) getPolicy(req policyprovider.GetPolicyRequest) (*policyprovider.PolicyResponse, error) { + policyPath := filepath.Join(p.basePath, req.Version, req.ID) + + fileInfo, err := os.Stat(policyPath) + if err != nil { + return nil, fmt.Errorf("policy not found: %w", err) + } + + if fileInfo.IsDir() { + return nil, fmt.Errorf("policy path is a directory, not a file") + } + + content, err := os.ReadFile(policyPath) + if err != nil { + return nil, fmt.Errorf("failed to read policy: %w", err) + } + + return &policyprovider.PolicyResponse{ + ID: req.ID, + Version: req.Version, + Content: content, + }, nil +} diff --git a/abac/policyprovider/filestore/policyprovider_test.go b/abac/policyprovider/filestore/policyprovider_test.go new file mode 100644 index 0000000..aac03b8 --- /dev/null +++ b/abac/policyprovider/filestore/policyprovider_test.go @@ -0,0 +1,143 @@ +package filestore + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/CameronXie/access-control-explorer/abac/policyprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPolicyProvider_GetPolicies(t *testing.T) { + tempDir := setupTestDir(t) + defer func() { + require.NoError(t, os.RemoveAll(tempDir)) + }() + + testCases := map[string]struct { + basePath string + requests []policyprovider.GetPolicyRequest + setupContext func() context.Context + expectedResult []policyprovider.PolicyResponse + expectedError string + }{ + "should retrieve multiple policies successfully": { + basePath: tempDir, + requests: []policyprovider.GetPolicyRequest{ + {ID: "policy1", Version: "v1"}, + {ID: "policy2", Version: "v1"}, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []policyprovider.PolicyResponse{ + {ID: "policy1", Version: "v1", Content: []byte("policy1 content")}, + {ID: "policy2", Version: "v1", Content: []byte("policy2 content")}, + }, + }, + + "should retrieve single policy successfully": { + basePath: tempDir, + requests: []policyprovider.GetPolicyRequest{ + {ID: "policy1", Version: "v2"}, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []policyprovider.PolicyResponse{ + {ID: "policy1", Version: "v2", Content: []byte("policy1 v2 content")}, + }, + }, + + "should return error when policy does not exist": { + basePath: tempDir, + requests: []policyprovider.GetPolicyRequest{ + {ID: "nonexistent", Version: "v1"}, + }, + setupContext: func() context.Context { return context.Background() }, + expectedError: "failed to get policy nonexistent@v1: policy not found", + }, + + "should return error when policy path is directory": { + basePath: tempDir, + requests: []policyprovider.GetPolicyRequest{ + {ID: "dir-policy", Version: "v1"}, + }, + setupContext: func() context.Context { return context.Background() }, + expectedError: "failed to get policy dir-policy@v1: policy path is a directory, not a file", + }, + + "should return error when context is cancelled": { + basePath: tempDir, + requests: []policyprovider.GetPolicyRequest{ + {ID: "policy1", Version: "v1"}, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedError: "context canceled", + }, + + "should return error when base path does not exist": { + basePath: "/nonexistent/path", + requests: []policyprovider.GetPolicyRequest{ + {ID: "policy1", Version: "v1"}, + }, + setupContext: func() context.Context { return context.Background() }, + expectedError: "failed to get policy policy1@v1: policy not found", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + provider := New(tc.basePath) + + // Setup context + ctx := tc.setupContext() + + // Execute + result, err := provider.GetPolicies(ctx, tc.requests) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.ElementsMatch(t, tc.expectedResult, result) + } + }) + } +} + +func setupTestDir(t *testing.T) string { + // Create a temporary test directory + tempDir, err := os.MkdirTemp("", "policy-test") + require.NoError(t, err) + + // Create version directories + v1Dir := filepath.Join(tempDir, "v1") + v2Dir := filepath.Join(tempDir, "v2") + + for _, dir := range []string{v1Dir, v2Dir} { + require.NoError(t, os.MkdirAll(dir, 0755)) + } + + // Create policy files + policies := map[string][]byte{ + filepath.Join(v1Dir, "policy1"): []byte("policy1 content"), + filepath.Join(v1Dir, "policy2"): []byte("policy2 content"), + filepath.Join(v2Dir, "policy1"): []byte("policy1 v2 content"), + } + + for path, content := range policies { + require.NoError(t, os.WriteFile(path, content, 0644)) //nolint:gosec // unit test + } + + // Create a directory instead of a file to test error case + dirPolicyPath := filepath.Join(v1Dir, "dir-policy") + require.NoError(t, os.MkdirAll(dirPolicyPath, 0755)) + + return tempDir +} diff --git a/abac/policyprovider/policyprovider.go b/abac/policyprovider/policyprovider.go new file mode 100644 index 0000000..1cf09fc --- /dev/null +++ b/abac/policyprovider/policyprovider.go @@ -0,0 +1,23 @@ +package policyprovider + +import "context" + +// GetPolicyRequest represents a request to retrieve a specific policy +type GetPolicyRequest struct { + ID string + Version string +} + +// PolicyResponse contains a policy's metadata and content +type PolicyResponse struct { + ID string + Version string + Content []byte +} + +// PolicyProvider defines the interface for retrieving policies +type PolicyProvider interface { + // GetPolicies retrieves multiple policies in a single call + // Returns policy responses for each request or an error if retrieval fails + GetPolicies(ctx context.Context, reqs []GetPolicyRequest) ([]PolicyResponse, error) +} diff --git a/abac/requestorchestrator/requestorchestrator.go b/abac/requestorchestrator/requestorchestrator.go new file mode 100644 index 0000000..f198653 --- /dev/null +++ b/abac/requestorchestrator/requestorchestrator.go @@ -0,0 +1,82 @@ +package requestorchestrator + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +type Decision string + +const ( + Permit Decision = "Permit" // Request is allowed + Deny Decision = "Deny" // Request is denied + Indeterminate Decision = "Indeterminate" // Errors prevented making a decision + NotApplicable Decision = "NotApplicable" // No applicable policy was found +) + +type StatusCode string + +const ( + StatusOK StatusCode = "OK" // Decision was successfully evaluated + StatusMissingAttribute StatusCode = "AttributeMissing" // A required attribute is missing + StatusProcessingError StatusCode = "ProcessingError" // An internal processing error occurred + StatusInvalidRequest StatusCode = "InvalidRequest" // The request is malformed + StatusPolicyNotFound StatusCode = "PolicyNotFound" // No matching policies were found + StatusEvaluationError StatusCode = "EvaluationError" // General evaluation error +) + +type Subject struct { + ID string `json:"id"` + Type string `json:"type"` +} + +type Action struct { + ID string `json:"id"` +} + +type Resource struct { + ID string `json:"id"` + Type string `json:"type"` +} + +type AccessRequest struct { + Subject Subject `json:"subject"` + Action Action `json:"action"` + Resource Resource `json:"resource"` +} + +type Obligation struct { + ID string `json:"id"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +type Advice struct { + ID string `json:"id"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +type Status struct { + Code StatusCode `json:"code"` + Message string `json:"message"` +} + +type PolicyIdReference struct { + ID string `json:"id"` + Version string `json:"version"` +} + +type AccessResponse struct { + RequestID uuid.UUID `json:"requestId"` + Decision Decision `json:"decision"` + Status Status `json:"status"` + Obligations []Obligation `json:"obligations,omitempty"` + Advices []Advice `json:"advices,omitempty"` + EvaluatedAt time.Time `json:"evaluatedAt"` + PolicyIdReferences []PolicyIdReference `json:"policyIdReferences"` +} + +type RequestOrchestrator interface { + EvaluateAccess(ctx context.Context, req *AccessRequest) (*AccessResponse, error) +} diff --git a/docker-compose.yaml b/docker-compose.yaml index 768d8a8..cd1d894 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -14,16 +14,19 @@ services: ports: - '8080:8080' depends_on: - - mysql + - postgres - mysql: - container_name: access_control_explorer_sql - image: mysql:9.1 + postgres: + container_name: access_control_explorer_postgres + image: postgres:17.4-bookworm + volumes: + - ./docker/postgres/init-db.sh:/docker-entrypoint-initdb.d/init-db.sh environment: - MYSQL_DATABASE: ${MYSQL_CASBIN_DATABASE} - MYSQL_USER: ${MYSQL_USER} - MYSQL_PASSWORD: ${MYSQL_PASSWORD} - MYSQL_ALLOW_EMPTY_PASSWORD: 1 + POSTGRES_USER: ${POSTGRES_USER} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_DB: postgres + POSTGRES_DB_DEMO: ${POSTGRES_DB_DEMO:-demo_db} + POSTGRES_DB_TEST: ${POSTGRES_DB_TEST:-test_db} restart: always ports: - - "3306:3306" + - "5432:5432" diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index f33ed3a..b9321cc 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23-bookworm +FROM golang:1.24-bookworm RUN set -eux \ && apt-get update && apt-get install -y --no-install-recommends \ @@ -18,16 +18,20 @@ RUN set -eux \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* -ARG ACTIONLINT_VERSION=1.7.3 -ARG GOLANGCILINT_VERSION=1.61.0 +ARG ACTIONLINT_VERSION=1.7.7 +ARG GOLANGCILINT_VERSION=2.1.6 +ARG GOLANG_MIGRATE_VERSION=4.18.3 WORKDIR /tmp/build -RUN \ +RUN set -eux \ # install golangcli-lint - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | \ + && curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | \ sh -s -- -b $(go env GOPATH)/bin v${GOLANGCILINT_VERSION} \ # install actionlint && wget -q -O actionlint.tar.gz https://github.com/rhysd/actionlint/releases/download/v${ACTIONLINT_VERSION}/actionlint_${ACTIONLINT_VERSION}_linux_amd64.tar.gz \ && tar -xzf actionlint.tar.gz \ && mv actionlint /usr/local/bin \ + # install golang-migrate + && curl -L https://github.com/golang-migrate/migrate/releases/download/v${GOLANG_MIGRATE_VERSION}/migrate.linux-amd64.tar.gz | tar xvz \ + && mv migrate /usr/local/bin \ && rm -rf /tmp/build diff --git a/docker/postgres/init-db.sh b/docker/postgres/init-db.sh new file mode 100755 index 0000000..6db3263 --- /dev/null +++ b/docker/postgres/init-db.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -e + +# Create databases using environment variables +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "postgres" <<-EOSQL + CREATE DATABASE ${POSTGRES_DB_DEMO:-demo_db}; + CREATE DATABASE ${POSTGRES_DB_TEST:-test_db}; + GRANT ALL PRIVILEGES ON DATABASE ${POSTGRES_DB_DEMO:-demo_db} TO $POSTGRES_USER; + GRANT ALL PRIVILEGES ON DATABASE ${POSTGRES_DB_TEST:-test_db} TO $POSTGRES_USER; +EOSQL \ No newline at end of file diff --git a/examples/abac/.golangci.yml b/examples/abac/.golangci.yml new file mode 100644 index 0000000..ef2ca7e --- /dev/null +++ b/examples/abac/.golangci.yml @@ -0,0 +1,109 @@ +version: "2" + +linters: + default: none + enable: + - bodyclose + - copyloopvar + - dogsled + - dupl + - errcheck + - errorlint + - funlen + - gocheckcompilerdirectives + - gochecknoinits + - goconst + - gocritic + - gocyclo + - goprintffuncname + - gosec + - govet + - ineffassign + - intrange + - lll + - mnd + - nakedret + - noctx + - nolintlint + - revive + - staticcheck + - unconvert + - unparam + - unused + - whitespace + settings: + dupl: + threshold: 100 + funlen: + lines: -1 + statements: 50 + goconst: + min-len: 2 + min-occurrences: 3 + gocritic: + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + disabled-checks: + - dupImport + - ifElseChain + - octalLiteral + - whyNoLint + gocyclo: + min-complexity: 15 + lll: + line-length: 140 + mnd: + checks: + - argument + - case + - condition + - return + ignored-numbers: + - "0" + - "1" + - "2" + - "3" + ignored-functions: + - strings.SplitN + nolintlint: + require-explanation: true + require-specific: true + allow-unused: false + revive: + rules: + - name: indent-error-flow + - name: unexported-return + disabled: true + - name: unused-parameter + - name: unused-receiver + + exclusions: + generated: strict + warn-unused: true + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - path: "_test.go" + text: "unlambda" + linters: + - gocritic + +formatters: + enable: + - gofmt + - goimports + settings: + gofmt: + rewrite-rules: + - pattern: interface{} + replacement: any + exclusions: + generated: strict + warn-unused: true diff --git a/examples/abac/Makefile b/examples/abac/Makefile new file mode 100644 index 0000000..745692c --- /dev/null +++ b/examples/abac/Makefile @@ -0,0 +1,68 @@ +DIST_DIR := _dist +GO_CODE_DIR := pkg internal +TEST_OUTPUT_DIR := ${DIST_DIR}/tests +BUILD_DIR := ${DIST_DIR}/build + +# Database migrations +MIGRATIONS_DIR := database/migrations + +# Database seed +SEEDS_DIR := database/seeds +SEED_MIGRATIONS_TABLE := seed_version + +DEFAULT_VERSION := v0.0.0 + +.PHONY: migrate-db +migrate-db: + @migrate -source file://$(MIGRATIONS_DIR) \ + -database postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST)/$(POSTGRES_DB_DEMO)?sslmode=$(POSTGRES_SSL_MODE) \ + up + @migrate -source file://$(MIGRATIONS_DIR) \ + -database postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST)/$(POSTGRES_DB_TEST)?sslmode=$(POSTGRES_SSL_MODE) \ + up + +.PHONY: seed-db +seed-db: + @migrate -source file://$(SEEDS_DIR) \ + -database "postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST)/$(POSTGRES_DB_DEMO)?sslmode=$(POSTGRES_SSL_MODE)&x-migrations-table=$(SEED_MIGRATIONS_TABLE)" \ + up + +# App +.PHONY: test +test: migrate-db lint-go test-go + +.PHONY: build +build: cleanup-build + @CURRENT_VERSION=$(shell git describe --tags --abbrev=0 2>/dev/null || echo $(DEFAULT_VERSION)); \ + echo "Building api (version $$CURRENT_VERSION)..."; \ + go build -o $(BUILD_DIR)/api \ + -a -ldflags "-X 'github.com/CameronXie/access-control-explorer/examples/abac/internal/version.Version=$$CURRENT_VERSION' -extldflags '-s -w -static'" \ + ./cmd/api; \ + cp -r ./cmd/api/policies $(BUILD_DIR)/policies + +.PHONY: lint-go +lint-go: + @echo "Running Go linter on code in $(GO_CODE_DIR)..." + @golangci-lint fmt $(addsuffix /..., $(GO_CODE_DIR)) -v + @golangci-lint run $(addsuffix /..., $(GO_CODE_DIR)) -v + +.PHONY: test-go +test-go: + @rm -rf ${TEST_OUTPUT_DIR} + @mkdir -p ${TEST_OUTPUT_DIR} + @go clean -testcache + @echo "Running Go tests..." + @go test \ + -cover \ + -coverprofile=cp.out \ + -outputdir=${TEST_OUTPUT_DIR} \ + -race \ + -v \ + -failfast \ + $(addprefix `pwd`/, $(addsuffix /..., $(GO_CODE_DIR))) + @go tool cover -html=${TEST_OUTPUT_DIR}/cp.out -o ${TEST_OUTPUT_DIR}/cp.html + +.PHONY: cleanup-build +cleanup-build: + @rm -rf ${BUILD_DIR} + @mkdir -p ${BUILD_DIR} diff --git a/examples/abac/README.md b/examples/abac/README.md new file mode 100644 index 0000000..12eb5c4 --- /dev/null +++ b/examples/abac/README.md @@ -0,0 +1,353 @@ +# E-Commerce Store API + +[![Test](https://github.com/CameronXie/access-control-explorer/actions/workflows/test.yaml/badge.svg)](https://github.com/CameronXie/access-control-explorer/actions/workflows/test.yaml) + +## Overview + +The E-Commerce Store API demonstrates a comprehensive implementation of Attribute-Based Access Control (ABAC) for +authorisation in a REST API context, with support for Role-Based Access Control (RBAC) patterns. This example showcases +how ABAC can provide flexible and fine-grained access control while maintaining the familiar concepts of roles and +permissions. + +The implementation uses PostgreSQL for RBAC information storage with support for hierarchical role structures, OPA/Rego +for policy evaluation, and JWT-based authentication with automatic user context enrichment. This example is designed +for demonstration and educational purposes, featuring simplified authentication mechanisms that are not recommended for +production use. + +## Architecture + + +```mermaid +graph TB + Client[Client Application] + subgraph "API Gateway" + Auth[Auth Handler
/auth/signin] + Health[Health Check
/health] + end + + subgraph "Protected API" + JWT[JWT Auth Middleware
Token Validation] + Enforcer[Enforcer
Policy Enforcement Point] + StoreAPI[Store API
/api/v1] + end + + subgraph "ABAC Components" + RequestOrchestrator[Request Orchestrator
Context Handler] + DecisionMaker[Decision Maker
Policy Decision Point] + PolicyProvider[Policy Provider
Policy Retrieval Point] + InfoProviders[Info Providers
Policy Information Point] + end + + subgraph "Data Layer" + PostgreSQL[(PostgreSQL
Users, Orders, RBAC)] + RegoFiles[Rego Policies
default.rego, rbac.rego] + end + + subgraph "External Systems" + AuditLog[Audit Logging
Obligation Handler] + CacheHint[Cache Hint
Advice Handler] + end + + Client --> Auth + Client --> Health + Client --> JWT + + Auth --> PostgreSQL + JWT --> Enforcer + Enforcer --> RequestOrchestrator + RequestOrchestrator --> InfoProviders + RequestOrchestrator --> DecisionMaker + DecisionMaker --> PolicyProvider + PolicyProvider --> RegoFiles + InfoProviders --> PostgreSQL + StoreAPI --> PostgreSQL + + Enforcer --> StoreAPI + Enforcer --> AuditLog + Enforcer --> CacheHint + + classDef xacml fill:#e1f5fe + classDef storage fill:#f3e5f5 + classDef external fill:#e8f5e8 + classDef middleware fill:#fff3e0 + + class RequestOrchestrator,Enforcer,DecisionMaker,PolicyProvider,InfoProviders xacml + class PostgreSQL,RegoFiles storage + class AuditLog,CacheHint external + class JWT middleware +``` + +## Key Features + +- **XACML-Style Architecture**: Implementation follows XACML patterns with clear separation of Policy Enforcement + Point (PEP), Policy Decision Point (PDP), Policy Retrieval Point (PRP), and Context Handler +- **ABAC with RBAC Support**: Flexible attribute-based access control that naturally supports role-based patterns + through policy configuration +- **Role Hierarchy**: PostgreSQL-backed role hierarchy supporting inheritance and complex organisational structures +- **OPA Integration**: Policy evaluation using Open Policy Agent with Rego policy language +- **JWT Authentication**: RS256-signed JWT tokens for stateless authentication with automatic user context enrichment +- **Middleware Chain**: JWT authentication middleware validates tokens and enriches request context before ABAC + enforcement +- **Automatic Ownership**: Orders automatically inherit ownership from an authenticated user context +- **Obligations and Advices**: Support for policy-driven actions (audit logging) and hints (caching) +- **Comprehensive Logging**: Structured logging for operational observability and audit trails + +## API Endpoints + +### Authentication + +#### POST /auth/signin + +Simplified authentication endpoint for demonstration purposes. + +**⚠️ Warning**: This endpoint is for demo purposes only and lacks critical security features such as password +verification, rate limiting, and multifactor authentication. Do not use in production environments. + +**Request**: + +```json +{ + "email": "alice@abac.com" +} +``` + +**Response**: + +```json +{ + "token": "", + "token_type": "Bearer" +} +``` + +### Order Management + +#### POST /api/v1/orders + +Create a new order. Requires valid JWT authentication and appropriate permissions. The order automatically inherits +ownership from the authenticated user and is initialised with the "created" status. + +**Request**: + +```json +{ + "name": "order-123", + "attributes": { + "priority": "high", + "total_amount": "299.99" + } +} +``` + +**Response**: + +```json +{ + "id": "123e4567-e89b-12d3-a456-426614174000", + "name": "order-123", + "attributes": { + "owner": "", + "status": "created", + "priority": "high", + "total_amount": "299.99" + } +} +``` + +#### GET /api/v1/orders/{id} + +Retrieve an order by ID. Requires valid JWT authentication and appropriate permissions. + +**Response**: + +```json +{ + "id": "123e4567-e89b-12d3-a456-426614174000", + "name": "order-123", + "attributes": { + "owner": "", + "status": "created", + "priority": "high", + "total_amount": "299.99" + } +} +``` + +### Health Check + +#### GET /health + +Service health check endpoint. + +**Response**: + +```json +{ + "status": "healthy" +} +``` + +## Development Setup + +### Prerequisites + +- Docker and Docker Compose +- Make + +### Environment Configuration + +The application requires several environment variables for operation: + +- `POSTGRES_*`: Database connection parameters +- `PRIVATE_KEY_BASE64`: Base64-encoded RSA private key for JWT signing +- `PUBLIC_KEY_BASE64`: Base64-encoded RSA public key for JWT verification +- `JWT_ISSUER`: JWT token issuer identifier +- `JWT_AUDIENCE`: JWT token audience identifier +- `PORT`: HTTP server port (optional, defaults to 8080) + +### Local Development Environment + +Navigate to the project root and set up the development environment: + +```shell +make up +``` + +This command: + +- Generates RSA key pairs for JWT operations if not present +- Creates `.env` file from `.env.example` with generated keys +- Starts all required services via Docker Compose including PostgreSQL + +### Database Setup + +Initialise the database schema and seed data: + +```shell +# From within the Docker development container +make migrate-db +make seed-db +``` + +These commands create: + +- Role hierarchy (admin → customer_service → customer) +- Sample users with different role assignments +- Permission structures for order operations +- Demonstration orders with ownership attributes + +## Building and Running + +### Build the Application + +```shell +# From examples/abac directory +make build +``` + +Build artifacts are placed in `_dist/build/`: + +- `api`: Compiled application binary +- `policies/`: Rego policy files copied for runtime + +### Run the Application + +```shell +# Run with default configuration ./_dist/build/api +# Run with custom port +PORT=9090 ./_dist/build/api +``` + +The application logs its version and listening address on startup. + +## API Usage Examples + +### Authentication Flow + +```shell +curl -X POST http://localhost:8080/auth/signin +-H "Content-Type: application/json" +-d '{"email":"alice@abac.com"}' + +# Response includes token for subsequent requests +# { +# "token": "ey...", +# "token_type": "Bearer" +# } +``` + +### Order Operations + +```shell +# Create new order +curl -X POST http://localhost:8080/api/v1/orders +-H "Authorization: Bearer " +-H "Content-Type: application/json" +-d '{ "name": "laptop-order-001", "attributes": { "priority": "high", "category": "electronics" } }' + +# Retrieve order by ID +curl -X GET http://localhost:8080/api/v1/orders/ +-H "Authorization: Bearer " +``` + +### Health Check + +```shell +curl -X GET http://localhost:8080/health +``` + +## Testing + +### Unit Testing + +Run the complete test suite: + +```shell +make test +``` + +Run Go-specific tests with coverage: + +```shell +make test-go +``` + +Test artifacts including coverage reports are generated in `_dist/tests/`. + +## Policy Configuration + +The application uses two main Rego policy files: + +- `policies/default.rego`: Top-level policy combiner that merges subject and resource evaluation results +- `policies/rbac.rego`: Role-based access control implementation within ABAC framework + +Policy decisions trigger: + +- **Obligations**: `audit_logging` for access event logging +- **Advices**: `cache_hint` for client caching guidance via `X-ABAC-Decision-TTL` header + +## Troubleshooting + +### Common Issues + +**401 Unauthorized**: Verify JWT token validity and ensure proper Authorization header format +**403 Forbidden**: Check user roles and permissions in the database; verify policy evaluation +**500 Internal Server Error**: Review application logs for detailed error information +**Database Connection**: Ensure PostgreSQL is running and environment variables are correctly configured + +### Debug Information + +- Application logs provide structured output with request correlation IDs +- Policy decisions are logged with evaluation context +- Database queries include performance metrics +- JWT validation errors are logged with specific failure reasons + +## Security Considerations + +This example demonstrates ABAC implementation patterns but includes several simplifications for educational purposes: + +- **Authentication**: The signin endpoint lacks password verification, rate limiting, and account lockout protection +- **JWT Security**: Token rotation and revocation mechanisms are not implemented +- **Input Validation**: Production deployments should include comprehensive input sanitisation +- **Error Handling**: Internal error details are intentionally not exposed to clients +- **Audit Logging**: Production systems should implement comprehensive audit trails with tamper protection diff --git a/examples/abac/cmd/api/main.go b/examples/abac/cmd/api/main.go new file mode 100644 index 0000000..d5133ab --- /dev/null +++ b/examples/abac/cmd/api/main.go @@ -0,0 +1,255 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" + "github.com/CameronXie/access-control-explorer/abac/decisionmaker/policyevaluator/opa" + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" + "github.com/CameronXie/access-control-explorer/abac/policyprovider/filestore" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/advice" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/api/rest/handler" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/api/rest/middleware" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/enforcer" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/enforcer/jwt" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/enforcer/operations" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/obligation" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/policyresolver" + repository "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository/postgres" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/requestorchestrator/infoanalyser" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/version" + "github.com/CameronXie/access-control-explorer/examples/abac/pkg/keyfetcher" + "github.com/jackc/pgx/v5/pgxpool" +) + +const ( + DefaultPort = "8080" + + PolicyDir = "policies" + RegoQuery = "data.abac.result" + DefaultPolicyKey = "default.rego" + RBACPolicyKey = "rbac.rego" + PolicyVersion = "v1" + + TokenTTL = 1 * time.Hour + DecisionCacheHintHeaderName = "X-ABAC-Decision-TTL" + + JWTClockSkewTolerance = 5 * time.Minute +) + +func main() { + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + logger.Info("api_starting", "version", version.Version) + + // Database connection + dbPool, err := initializeDatabase(logger, fmt.Sprintf( + "postgres://%s:%s@%s/%s?sslmode=%s", + os.Getenv("POSTGRES_USER"), + os.Getenv("POSTGRES_PASSWORD"), + os.Getenv("POSTGRES_HOST"), + os.Getenv("POSTGRES_DB_DEMO"), + os.Getenv("POSTGRES_SSL"), + )) + if err != nil { + logger.Error("db_init_failed", "error", err) + os.Exit(1) + } + defer dbPool.Close() + + // Policy location + policyPath, err := resolvePolicyPath(PolicyDir, "POLICY_DIR") + if err != nil { + logger.Error("policy_path_resolve_failed", "error", err) + os.Exit(1) + } + + // Repositories + userRepo := repository.NewUserRepository(dbPool) + orderRepo := repository.NewOrderRepository(dbPool) + rbacRepo := repository.NewRBACRepository(dbPool) + + // Auth config + issuer := os.Getenv("JWT_ISSUER") + audience := os.Getenv("JWT_AUDIENCE") + + // Create JWT middleware + jwtMiddleware := middleware.NewJWTAuthMiddleware(middleware.JWTConfig{ + KeyFetcher: keyfetcher.FromBase64Env("PUBLIC_KEY_BASE64"), + Issuer: issuer, + Audience: audience, + ClockSkew: JWTClockSkewTolerance, + }) + + // Enforcer (PEP) + enforcerMiddleware, err := initEnforcer(policyPath, userRepo, orderRepo, rbacRepo, logger) + if err != nil { + logger.Error("enforcer_init_failed", "error", err) + os.Exit(1) + } + + // REST handlers + orderHandler := handler.NewOrderHandler(orderRepo, logger) + authHandler := handler.NewAuthHandler( + userRepo, + &handler.AuthConfig{ + KeyFetcher: keyfetcher.FromBase64Env("PRIVATE_KEY_BASE64"), + Issuer: issuer, + Audience: audience, + TokenTTL: TokenTTL, + }, + logger, + ) + + // Routing + mux := buildServeMux(authHandler, orderHandler, jwtMiddleware, enforcerMiddleware) + + // HTTP server with sensible timeouts + port := os.Getenv("PORT") + if port == "" { + port = DefaultPort + } + server := &http.Server{ + Addr: fmt.Sprintf(":%s", port), + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 20 * time.Second, + IdleTimeout: 60 * time.Second, + } + + logger.Info("api_listening", "addr", server.Addr) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Error("api_serve_failed", "error", err) + os.Exit(1) + } +} + +// initializeDatabase creates a pool and verifies connectivity. +func initializeDatabase(logger *slog.Logger, connectionString string) (*pgxpool.Pool, error) { + ctx := context.Background() + pool, err := pgxpool.New(ctx, connectionString) + if err != nil { + return nil, fmt.Errorf("create_pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("ping_db: %w", err) + } + + return pool, nil +} + +// initEnforcer wires PRP, PDP, Context Handler, and PEP middleware. +func initEnforcer( + policyPath string, + userRepo infoprovider.UserAttributesRepository, + orderRepo infoprovider.OrderAttributesRepository, + rbacRepo infoprovider.RBACRepository, + logger *slog.Logger, +) (*enforcer.Enforcer, error) { + // PRP: policy provider + policyProvider := filestore.New(policyPath) + + // PDP: decision maker + decisionMaker := decisionmaker.NewDecisionMaker( + policyProvider, + opa.NewEvaluator(RegoQuery), + decisionmaker.WithPolicyResolver(policyresolver.NewDefaultResolver(DefaultPolicyKey, PolicyVersion)), + decisionmaker.WithPolicyResolver(policyresolver.NewRBACResolver(RBACPolicyKey, PolicyVersion)), + ) + + // Context Handler: enrich request and call PDP + orchestrator := requestorchestrator.NewRequestOrchestrator( + []requestorchestrator.InfoAnalyser{ + infoanalyser.NewRBACAnalyser(infoprovider.InfoTypeRBAC), + }, + infoprovider.NewInfoProvider(map[infoprovider.InfoType]ip.InfoProvider{ + infoprovider.InfoTypeUser: infoprovider.NewUserProvider(userRepo), + infoprovider.InfoTypeOrder: infoprovider.NewOrderProvider(orderRepo), + infoprovider.InfoTypeRBAC: infoprovider.NewRoleBasedAccessProvider(rbacRepo), + }), + decisionMaker, + ) + + // HTTP request extractors for operations + orderCreateExtractor, err := operations.NewOrderExtractor(operations.ActionCreate) + if err != nil { + return nil, fmt.Errorf("new_order_create_extractor: %w", err) + } + + orderReadExtractor, err := operations.NewOrderExtractor( + operations.ActionRead, + operations.WithIDExtractor(operations.ExtractOrderIDFromPath), + ) + if err != nil { + return nil, fmt.Errorf("new_order_read_extractor: %w", err) + } + + // PEP request extractor + requestExtractor, err := enforcer.NewRequestExtractor( + enforcer.WithSubjectExtractor(jwt.NewSubjectExtractor()), + enforcer.WithOperationExtractor("/orders", http.MethodPost, orderCreateExtractor), + enforcer.WithOperationExtractor("/orders/*", http.MethodGet, orderReadExtractor), + ) + if err != nil { + return nil, fmt.Errorf("new_request_extractor: %w", err) + } + + // PEP middleware + return enforcer.NewEnforcer( + orchestrator, + requestExtractor, + logger, + enforcer.WithAdviceHandler("cache_hint", advice.NewCacheHintAdviceHandler(DecisionCacheHintHeaderName)), + enforcer.WithObligationHandler("audit_logging", obligation.NewAuditLogHandler(logger)), + ), nil +} + +// buildServeMux wires routes and applies the PEP to API endpoints. +func buildServeMux( + authHandler *handler.AuthHandler, + orderHandler *handler.OrderHandler, + jwtMiddleware *middleware.JWTAuthMiddleware, + enforcer *enforcer.Enforcer, +) *http.ServeMux { + root := http.NewServeMux() + root.Handle("GET /health", http.HandlerFunc(handleHealthCheck)) + + api := http.NewServeMux() + root.Handle("/api/v1/", http.StripPrefix("/api/v1", jwtMiddleware.Handler(enforcer.Enforce(api)))) + + root.Handle("POST /auth/signin", http.HandlerFunc(authHandler.SignIn)) + api.Handle("POST /orders", http.HandlerFunc(orderHandler.CreateOrder)) + api.Handle("GET /orders/{id}", http.HandlerFunc(orderHandler.GetOrderByID)) + return root +} + +// handleHealthCheck returns a basic health status. +func handleHealthCheck(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"healthy"}`)) +} + +// resolvePolicyPath prefers env var; falls back to executable dir. +func resolvePolicyPath(policyDir string, policyDirEnv string) (string, error) { + if policyPath := os.Getenv(policyDirEnv); policyPath != "" { + return policyPath, nil + } + + exe, err := os.Executable() + if err != nil { + return "", err + } + exeDir := filepath.Dir(exe) + return filepath.Join(exeDir, policyDir), nil +} diff --git a/examples/abac/cmd/api/policies/v1/default.rego b/examples/abac/cmd/api/policies/v1/default.rego new file mode 100644 index 0000000..ff588ab --- /dev/null +++ b/examples/abac/cmd/api/policies/v1/default.rego @@ -0,0 +1,41 @@ +package abac + +import data.abac.subject +import data.abac.resource + +# Top-level combiner: merges subject and resource results. +# Default: no applicable policy. +default result := { + "decision": "NotApplicable", + "status": { + "code": "PolicyNotFound", + "message": "no applicable policy was found for this request", + }, +} + +# Permit only if both subject and resource permit. +result = subject.result if { + subject.result.decision == "Permit" + resource.result.decision == "Permit" +} + +# If resource module not present, defer to subject. +result := subject.result if { + not resource.result +} + +# If subject denies (or not permit), return subject result. +result := subject.result if { + subject.result.decision != "Permit" +} + +# If subject module not present, defer to resource. +result := resource.result if { + not subject.result +} + +# If subject permits but resource does not, return resource result. +result := resource.result if { + subject.result.decision == "Permit" + resource.result.decision != "Permit" +} diff --git a/examples/abac/cmd/api/policies/v1/rbac.rego b/examples/abac/cmd/api/policies/v1/rbac.rego new file mode 100644 index 0000000..9f0ffc2 --- /dev/null +++ b/examples/abac/cmd/api/policies/v1/rbac.rego @@ -0,0 +1,64 @@ +package abac.subject + +# RBAC subject evaluation: finds an applicable permission for the subject. +# On Permit, returns cache advice and an audit obligation. +result := r if { + some role in input.environment.role_hierarchy.descendants + some permission in input.environment.role_permissions[role] + is_permission_applicable(permission, input) + + r = { + "decision": "Permit", + "status": {"code": "OK"}, + "advices": [{ + "id": "cache_hint", + "attributes": {"ttl_seconds": 30}, + }], + "obligations": [{ + "id": "audit_logging", + "attributes": { + "level": "INFO", + "message": sprintf("permit: subject=%s/%s action=%s resource=%s/%s", [input.subject.type, input.subject.id, input.action.id, input.resource.type, input.resource.id]), + }, + }], + } +} + +# Permission matches action/resource and all conditions are satisfied. +is_permission_applicable(permission, access_context) if { + permission.action == access_context.action.id + permission.resource == access_context.resource.type + + all_conditions_satisfied(object.get(permission, "conditions", {}), access_context) +} + +# No conditions means satisfied. +all_conditions_satisfied(conditions, _) if { + count(conditions) == 0 +} + +# At least one condition must be satisfied (OR semantics across conditions). +all_conditions_satisfied(conditions, access_context) if { + some condition in conditions + is_condition_satisfied(condition, access_context) +} + +# Evaluate a single condition. +is_condition_satisfied(condition, access_context) if { + actual_value = access_context.resource.attributes[condition.attribute_key] + expected_value = resolve_condition_attribute_value(condition.attribute_value, access_context) + apply_operator(condition.operator, actual_value, expected_value) +} + +# Supported operators. +apply_operator(operator, actual_value, expected_value) if { + operator == "equals" + actual_value == expected_value +} + +# Resolve ${...} references from access context; fallback to literal. +resolve_condition_attribute_value(attribute_value, access_context) := r if { + matches := regex.find_all_string_submatch_n(`^\${([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)}$`, attribute_value, 1) + count(matches) > 0 + r := object.get(access_context, split(matches[0][1], "."), "") +} else := attribute_value diff --git a/examples/abac/database/migrations/000001_create_abac_tables.down.sql b/examples/abac/database/migrations/000001_create_abac_tables.down.sql new file mode 100644 index 0000000..e394e28 --- /dev/null +++ b/examples/abac/database/migrations/000001_create_abac_tables.down.sql @@ -0,0 +1,9 @@ +DROP TABLE IF EXISTS orders; +DROP TABLE IF EXISTS role_permission_conditions; +DROP TABLE IF EXISTS role_permissions; +DROP TABLE IF EXISTS resources; +DROP TABLE IF EXISTS actions; +DROP TABLE IF EXISTS role_hierarchy; +DROP TABLE IF EXISTS user_roles; +DROP TABLE IF EXISTS roles; +DROP TABLE IF EXISTS users; diff --git a/examples/abac/database/migrations/000001_create_abac_tables.up.sql b/examples/abac/database/migrations/000001_create_abac_tables.up.sql new file mode 100644 index 0000000..bb1d446 --- /dev/null +++ b/examples/abac/database/migrations/000001_create_abac_tables.up.sql @@ -0,0 +1,86 @@ +CREATE TABLE users +( + id UUID PRIMARY KEY, + email VARCHAR(255) NOT NULL UNIQUE, + attributes JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_users_roles ON users USING GIN ((attributes -> 'roles')); + +CREATE TABLE roles +( + id UUID PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE role_hierarchy +( + id UUID PRIMARY KEY, + parent_role_id UUID NOT NULL REFERENCES roles (id) ON DELETE CASCADE, + child_role_id UUID NOT NULL REFERENCES roles (id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + + UNIQUE (parent_role_id, child_role_id), + CHECK (parent_role_id != child_role_id) +); + +CREATE TABLE actions +( + id UUID PRIMARY KEY, + name VARCHAR(50) NOT NULL UNIQUE, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE resources +( + id UUID PRIMARY KEY, + name VARCHAR(50) NOT NULL UNIQUE, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE role_permissions +( + id UUID PRIMARY KEY, + role_id UUID NOT NULL REFERENCES roles (id) ON DELETE CASCADE, + action_id UUID NOT NULL REFERENCES actions (id) ON DELETE CASCADE, + resource_id UUID NOT NULL REFERENCES resources (id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + + UNIQUE (role_id, action_id, resource_id) +); + +CREATE TABLE role_permission_conditions +( + permission_id UUID NOT NULL REFERENCES role_permissions (id) ON DELETE CASCADE, + attribute_key VARCHAR(100) NOT NULL, + operator VARCHAR(20) NOT NULL, + attribute_value JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + + UNIQUE (permission_id, attribute_key, operator) +); + +CREATE TABLE orders +( + id UUID PRIMARY KEY, + name VARCHAR(50) NOT NULL UNIQUE, + attributes JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_orders_user_id ON orders USING BTREE ((attributes ->> 'user_id')); +CREATE INDEX idx_orders_total_amount ON orders USING BTREE (((attributes ->> 'total_amount')::numeric)); +CREATE INDEX idx_orders_status ON orders USING BTREE ((attributes ->> 'status')); diff --git a/examples/abac/database/seeds/000001_seed_demo_data.down.sql b/examples/abac/database/seeds/000001_seed_demo_data.down.sql new file mode 100644 index 0000000..6a46b86 --- /dev/null +++ b/examples/abac/database/seeds/000001_seed_demo_data.down.sql @@ -0,0 +1,66 @@ +-- 000001_seed_demo_data.down.sql + +-- Remove demo orders +DELETE +FROM orders +WHERE name IN ('order-001', 'order-002'); + +-- Remove demo users +DELETE +FROM users +WHERE email IN ('alice@abac.com', 'bob@abac.com', 'cara@abac.com'); + +-- Remove customer conditions on order permissions +WITH cust_perms AS (SELECT rp.id + FROM role_permissions rp + JOIN roles r ON rp.role_id = r.id + JOIN actions a ON rp.action_id = a.id + JOIN resources res ON rp.resource_id = res.id + WHERE r.name = 'customer' + AND res.name = 'order' + AND a.name IN ('create', 'read')) +DELETE +FROM role_permission_conditions rpc + USING cust_perms cp +WHERE rpc.permission_id = cp.id + AND rpc.attribute_key = 'owner' + AND rpc.operator = 'equals'; + +-- Remove role_permissions inserted for demo (admin, customer_service, customer on order) +DELETE +FROM role_permissions rp + USING roles r, actions a, resources res +WHERE rp.role_id = r.id + AND rp.action_id = a.id + AND rp.resource_id = res.id + AND res.name = 'order' + AND r.name IN ('admin', 'customer_service', 'customer') + AND a.name IN ('create', 'read'); + +-- Remove role hierarchy links for demo +DELETE +FROM role_hierarchy rh + USING roles parent, roles child +WHERE rh.parent_role_id = parent.id + AND rh.child_role_id = child.id + AND ( + (parent.name = 'admin' AND child.name = 'customer_service') OR + (parent.name = 'customer_service' AND child.name = 'customer') + ); + +-- Remove demo roles (safe even if used elsewhere because we removed dependent rows) +DELETE +FROM roles +WHERE name IN ('admin', 'customer_service', 'customer'); + +-- Optionally remove actions and resource if they were created for demo +DELETE +FROM role_permissions rp +WHERE NOT EXISTS (SELECT 1 FROM roles r WHERE r.id = rp.role_id); -- safety cleanup + +DELETE +FROM actions +WHERE name IN ('create', 'read'); +DELETE +FROM resources +WHERE name = 'order'; \ No newline at end of file diff --git a/examples/abac/database/seeds/000001_seed_demo_data.up.sql b/examples/abac/database/seeds/000001_seed_demo_data.up.sql new file mode 100644 index 0000000..67cb58c --- /dev/null +++ b/examples/abac/database/seeds/000001_seed_demo_data.up.sql @@ -0,0 +1,103 @@ +-- 000001_seed_demo_data.up.sql + +-- Ensure required roles exist +INSERT INTO roles (id, name, description) +VALUES + (gen_random_uuid(), 'admin', 'Administrator role'), + (gen_random_uuid(), 'customer_service', 'Customer service role'), + (gen_random_uuid(), 'customer', 'Customer role') +ON CONFLICT (name) DO NOTHING; + +-- Role hierarchy: admin -> customer_service -> customer +INSERT INTO role_hierarchy (id, parent_role_id, child_role_id) +SELECT gen_random_uuid(), parent.id, child.id +FROM roles parent + JOIN roles child ON true +WHERE (parent.name = 'admin' AND child.name = 'customer_service') + OR (parent.name = 'customer_service' AND child.name = 'customer') +ON CONFLICT (parent_role_id, child_role_id) DO NOTHING; + +-- Actions and resource +INSERT INTO actions (id, name, description) +VALUES + (gen_random_uuid(), 'create', 'Create order'), + (gen_random_uuid(), 'read', 'Read order') +ON CONFLICT (name) DO NOTHING; + +INSERT INTO resources (id, name, description) +VALUES (gen_random_uuid(), 'order', 'Order resource') +ON CONFLICT (name) DO NOTHING; + +-- Role permissions +-- admin: create/read any order +INSERT INTO role_permissions (id, role_id, action_id, resource_id) +SELECT gen_random_uuid(), r.id, a.id, res.id +FROM roles r, actions a, resources res +WHERE r.name = 'admin' AND res.name = 'order' AND a.name IN ('create','read') +ON CONFLICT (role_id, action_id, resource_id) DO NOTHING; + +-- customer_service: read any order +INSERT INTO role_permissions (id, role_id, action_id, resource_id) +SELECT gen_random_uuid(), r.id, a.id, res.id +FROM roles r, actions a, resources res +WHERE r.name = 'customer_service' AND res.name = 'order' AND a.name = 'read' +ON CONFLICT (role_id, action_id, resource_id) DO NOTHING; + +-- customer: create/read only own order (add conditions) +INSERT INTO role_permissions (id, role_id, action_id, resource_id) +SELECT gen_random_uuid(), r.id, a.id, res.id +FROM roles r, actions a, resources res +WHERE r.name = 'customer' AND res.name = 'order' AND a.name IN ('create','read') +ON CONFLICT (role_id, action_id, resource_id) DO NOTHING; + +-- Attach conditions to the customer's permissions +-- owner == ${subject.id} +INSERT INTO role_permission_conditions (permission_id, attribute_key, operator, attribute_value) +SELECT rp.id, 'owner', 'equals', '"${subject.id}"'::jsonb +FROM role_permissions rp + JOIN roles r ON rp.role_id = r.id + JOIN actions a ON rp.action_id = a.id + JOIN resources res ON rp.resource_id = res.id +WHERE r.name = 'customer' AND res.name = 'order' AND a.name IN ('create','read') +ON CONFLICT (permission_id, attribute_key, operator) DO NOTHING; + +-- Demo users +INSERT INTO users (id, email, attributes) +VALUES + (gen_random_uuid(), 'alice@abac.com', '{"roles":["admin"],"department":"operations","region":"global"}'), + (gen_random_uuid(), 'bob@abac.com', '{"roles":["customer_service"],"department":"support","region":"na"}'), + (gen_random_uuid(), 'cara@abac.com', '{"roles":["customer"],"department":"consumer","region":"eu"}') +ON CONFLICT (email) DO NOTHING; + +-- Demo orders (owned by cara) +INSERT INTO role_permission_conditions (permission_id, attribute_key, operator, attribute_value) +SELECT rp.id, 'owner', 'equals', '"${subject.id}"'::jsonb +FROM role_permissions rp + JOIN roles r ON rp.role_id = r.id + JOIN actions a ON rp.action_id = a.id + JOIN resources res ON rp.resource_id = res.id +WHERE r.name = 'customer' AND res.name = 'order' AND a.name IN ('create','read') +ON CONFLICT (permission_id, attribute_key, operator) DO NOTHING; + +-- Demo users +INSERT INTO users (id, email, attributes) +VALUES + (gen_random_uuid(), 'alice@abac.com', '{"roles":["admin"],"department":"operations","region":"global"}'), + (gen_random_uuid(), 'bob@abac.com', '{"roles":["customer_service"],"department":"support","region":"na"}'), + (gen_random_uuid(), 'cara@abac.com', '{"roles":["customer"],"department":"consumer","region":"eu"}') +ON CONFLICT (email) DO NOTHING; + +-- Demo orders (owned by cara - owner is user ID) +INSERT INTO orders (id, name, attributes) +VALUES + (gen_random_uuid(), 'order-001', jsonb_build_object( + 'owner', (SELECT id FROM users WHERE email = 'bob@abac.com'), + 'total_amount', '123.45', + 'status', 'created' + )), + (gen_random_uuid(), 'order-002', jsonb_build_object( + 'owner', (SELECT id FROM users WHERE email = 'cara@abac.com'), + 'total_amount', '42.00', + 'status', 'created' + )) +ON CONFLICT (name) DO NOTHING; diff --git a/examples/abac/go.mod b/examples/abac/go.mod new file mode 100644 index 0000000..80d261f --- /dev/null +++ b/examples/abac/go.mod @@ -0,0 +1,59 @@ +module github.com/CameronXie/access-control-explorer/examples/abac + +go 1.24.4 + +require ( + github.com/CameronXie/access-control-explorer v0.0.0 + golang.org/x/sync v0.16.0 +) + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/mux v1.8.1 + github.com/jackc/pgx/v5 v5.7.5 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/agnivade/levenshtein v1.2.1 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-ini/ini v1.67.0 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/gobwas/glob v0.2.3 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/open-policy-agent/opa v1.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_golang v1.21.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stretchr/objx v0.5.2 // indirect + github.com/tchap/go-patricia/v2 v2.3.2 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/yashtewari/glob-intersection v0.2.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.34.0 // indirect + go.opentelemetry.io/otel/metric v1.34.0 // indirect + go.opentelemetry.io/otel/sdk v1.34.0 // indirect + go.opentelemetry.io/otel/trace v1.34.0 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/sys v0.32.0 // indirect + golang.org/x/text v0.24.0 // indirect + golang.org/x/tools v0.26.0 // indirect + google.golang.org/protobuf v1.36.3 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + sigs.k8s.io/yaml v1.4.0 // indirect +) + +replace github.com/CameronXie/access-control-explorer => ../../ diff --git a/examples/abac/go.sum b/examples/abac/go.sum new file mode 100644 index 0000000..633d6d6 --- /dev/null +++ b/examples/abac/go.sum @@ -0,0 +1,157 @@ +github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= +github.com/agnivade/levenshtein v1.2.1/go.mod h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtCyDIjcwga4/DU= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA= +github.com/bytecodealliance/wasmtime-go/v3 v3.0.2/go.mod h1:RnUjnIXxEJcL6BgCvNyzCCRzZcxCgsZCi+RNlvYor5Q= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger/v4 v4.5.1 h1:7DCIXrQjo1LKmM96YD+hLVJ2EEsyyoWxJfpdd56HLps= +github.com/dgraph-io/badger/v4 v4.5.1/go.mod h1:qn3Be0j3TfV4kPbVoK0arXCD1/nr1ftth6sbL5jxdoA= +github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= +github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= +github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7cNTs5R6Hk4V2lcmLz2NsG2VnInyNo= +github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= +github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/google/flatbuffers v24.12.23+incompatible h1:ubBKR94NR4pXUCY/MUsRVzd9umNW7ht7EG9hHfS9FX8= +github.com/google/flatbuffers v24.12.23+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= +github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM= +github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/open-policy-agent/opa v1.2.0 h1:88NDVCM0of1eO6Z4AFeL3utTEtMuwloFmWWU7dRV1z0= +github.com/open-policy-agent/opa v1.2.0/go.mod h1:30euUmOvuBoebRCcJ7DMF42bRBOPznvt0ACUMYDUGVY= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA= +github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= +github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tchap/go-patricia/v2 v2.3.2 h1:xTHFutuitO2zqKAQ5rCROYgUb7Or/+IC3fts9/Yc7nM= +github.com/tchap/go-patricia/v2 v2.3.2/go.mod h1:VZRHKAb53DLaG+nA9EaYYiaEx6YztwDlLElMsnSHD4k= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg= +github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I= +go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= +go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 h1:tgJ0uaNS4c98WRNUEx5U3aDlrDOI5Rs+1Vifcw4DJ8U= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0/go.mod h1:U7HYyW0zt/a9x5J1Kjs+r1f/d4ZHnYFclhYY2+YbeoE= +go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= +go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= +go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= +go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= +go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= +go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= +go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= +go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA= +google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= +google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= +google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= +google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= +google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= +sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/examples/abac/internal/advice/cache_hint_handler.go b/examples/abac/internal/advice/cache_hint_handler.go new file mode 100644 index 0000000..bb91362 --- /dev/null +++ b/examples/abac/internal/advice/cache_hint_handler.go @@ -0,0 +1,69 @@ +package advice + +import ( + "context" + "fmt" + "net/http" + "strconv" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" +) + +// CacheHintAdviceHandler sets a response header with the suggested TTL (in seconds) +// read from the "ttl_seconds" attribute in the "cache_hint" advice. +// By default, it writes "X-ABAC-Decision-TTL" header. You can override the header +// name via the constructor. +type CacheHintAdviceHandler struct { + HeaderName string +} + +// NewCacheHintAdviceHandler creates a new CacheHintAdviceHandler. +// If headerName is empty, it defaults to "X-ABAC-Decision-TTL". +func NewCacheHintAdviceHandler(headerName string) *CacheHintAdviceHandler { + if headerName == "" { + headerName = "X-ABAC-Decision-TTL" + } + return &CacheHintAdviceHandler{HeaderName: headerName} +} + +func (h *CacheHintAdviceHandler) Handle(_ context.Context, advice ro.Advice, w http.ResponseWriter, _ *http.Request) error { + // Expect attribute "ttl_seconds" + raw, ok := advice.Attributes["ttl_seconds"] + if !ok { + return fmt.Errorf("cache_hint advice missing 'ttl_seconds' attribute") + } + + ttl, err := toInt(raw) + if err != nil { + return fmt.Errorf("cache_hint invalid 'ttl_seconds': %w", err) + } + if ttl <= 0 { + return fmt.Errorf("cache_hint 'ttl_seconds' must be > 0, got %d", ttl) + } + + w.Header().Set(h.HeaderName, strconv.Itoa(ttl)) + return nil +} + +func toInt(v any) (int, error) { + switch t := v.(type) { + case int: + return t, nil + case int32: + return int(t), nil + case int64: + return int(t), nil + case float32: + return int(t), nil + case float64: + return int(t), nil + case string: + n, err := strconv.Atoi(t) + if err != nil { + return 0, fmt.Errorf("not a number: %v", t) + } + return n, nil + default: + return 0, fmt.Errorf("unsupported type %T", v) + } +} diff --git a/examples/abac/internal/advice/cache_hint_handler_test.go b/examples/abac/internal/advice/cache_hint_handler_test.go new file mode 100644 index 0000000..56f823c --- /dev/null +++ b/examples/abac/internal/advice/cache_hint_handler_test.go @@ -0,0 +1,88 @@ +package advice + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/stretchr/testify/assert" +) + +func TestCacheHintAdviceHandler(t *testing.T) { + tests := map[string]struct { + headerName string + adviceAttributes map[string]any + expectedHeaderName string + expectedHeaderVal string + expectError bool + }{ + "should set default header with int ttl": { + headerName: "", + adviceAttributes: map[string]any{ + "ttl_seconds": 30, + }, + expectedHeaderName: "X-ABAC-Decision-TTL", + expectedHeaderVal: "30", + expectError: false, + }, + "should set custom header with string ttl": { + headerName: "X-Custom-TTL", + adviceAttributes: map[string]any{ + "ttl_seconds": "45", + }, + expectedHeaderName: "X-Custom-TTL", + expectedHeaderVal: "45", + expectError: false, + }, + "should error when ttl is missing": { + headerName: "", + adviceAttributes: map[string]any{}, + expectedHeaderName: "X-ABAC-Decision-TTL", + expectedHeaderVal: "", + expectError: true, + }, + "should error when ttl type is invalid": { + headerName: "", + adviceAttributes: map[string]any{ + "ttl_seconds": []int{10}, + }, + expectedHeaderName: "X-ABAC-Decision-TTL", + expectedHeaderVal: "", + expectError: true, + }, + "should error when ttl is non-positive": { + headerName: "", + adviceAttributes: map[string]any{ + "ttl_seconds": 0, + }, + expectedHeaderName: "X-ABAC-Decision-TTL", + expectedHeaderVal: "", + expectError: true, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + h := NewCacheHintAdviceHandler(tc.headerName) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody) + + advice := ro.Advice{ + ID: "cache_hint", + Attributes: tc.adviceAttributes, + } + + err := h.Handle(context.Background(), advice, rr, req) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tc.expectedHeaderVal, rr.Header().Get(tc.expectedHeaderName)) + }) + } +} diff --git a/examples/abac/internal/api/rest/handler/auth_handler.go b/examples/abac/internal/api/rest/handler/auth_handler.go new file mode 100644 index 0000000..b9e0012 --- /dev/null +++ b/examples/abac/internal/api/rest/handler/auth_handler.go @@ -0,0 +1,145 @@ +// Package handler provides HTTP handlers for authentication. +// WARNING: This signin handler is for demo purposes only and should NOT be used in production. +// It lacks proper password validation, rate limiting, and other security measures. +package handler + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "net/http" + "time" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/CameronXie/access-control-explorer/examples/abac/pkg/keyfetcher" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// UserRepository defines the interface for user data access +type UserRepository interface { + GetUserIDByEmail(ctx context.Context, email string) (uuid.UUID, error) +} + +// AuthConfig holds authentication configuration +type AuthConfig struct { + KeyFetcher keyfetcher.PrivateKeyFetcher + Issuer string + Audience string + TokenTTL time.Duration +} + +// AuthHandler handles authentication requests +type AuthHandler struct { + userRepo UserRepository + config *AuthConfig + logger *slog.Logger +} + +// SignInRequest represents the signin request payload +type SignInRequest struct { + Email string `json:"email"` +} + +// SignInResponse represents the signin response payload +type SignInResponse struct { + Token string `json:"token"` + TokenType string `json:"token_type"` +} + +// JWTClaims contains minimal JWT claims for demo purposes +type JWTClaims struct { + jwt.RegisteredClaims +} + +// NewAuthHandler creates a new authentication handler +func NewAuthHandler(userRepo UserRepository, config *AuthConfig, logger *slog.Logger) *AuthHandler { + return &AuthHandler{ + userRepo: userRepo, + config: config, + logger: logger, + } +} + +// SignIn handles user signin requests +// WARNING: Demo implementation - lacks password verification and security measures +func (h *AuthHandler) SignIn(w http.ResponseWriter, r *http.Request) { + // Parse request body + var req SignInRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.logger.Warn("Invalid request format", "error", err) + WriteErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid request format") + return + } + + // Validate email + if req.Email == "" { + h.logger.Warn("Sign in attempt with empty email") + WriteErrorResponse(w, http.StatusBadRequest, "invalid_request", "Email is required") + return + } + + // Look up user ID by email + userID, err := h.userRepo.GetUserIDByEmail(r.Context(), req.Email) + if err != nil { + // Check if it's a not found error + var notFoundErr *repository.NotFoundError + if errors.As(err, ¬FoundErr) { + h.logger.Warn("Sign in attempt for non-existent user", "email", req.Email) + } else { + h.logger.Error("Failed to retrieve user during sign in", "email", req.Email, "error", err) + } + WriteErrorResponse(w, http.StatusUnauthorized, "authentication_failed", "Authentication failed") + return + } + + // Generate JWT token + token, err := h.generateJWT(userID) + if err != nil { + h.logger.Error("Failed to generate JWT token", "user_id", userID, "error", err) + WriteErrorResponse(w, http.StatusInternalServerError, "authentication_failed", "Authentication failed") + return + } + + h.logger.Info("Successful user sign in", "user_id", userID, "email", req.Email) + + // Return successful response + response := SignInResponse{ + Token: token, + TokenType: "Bearer", + } + + WriteJSONResponse(w, http.StatusOK, response) +} + +// generateJWT creates a JWT token for the authenticated user +func (h *AuthHandler) generateJWT(userID uuid.UUID) (string, error) { + // Fetch private key using keyfetcher + privateKey, err := h.config.KeyFetcher.FetchPrivateKey() + if err != nil { + return "", err + } + + now := time.Now() + expiresAt := now.Add(h.config.TokenTTL) + + // Create JWT claims with minimal required fields + claims := JWTClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: h.config.Issuer, + Subject: userID.String(), + Audience: jwt.ClaimStrings{h.config.Audience}, + ExpiresAt: jwt.NewNumericDate(expiresAt), + }, + } + + // Sign and return the token + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(privateKey) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/examples/abac/internal/api/rest/handler/auth_handler_test.go b/examples/abac/internal/api/rest/handler/auth_handler_test.go new file mode 100644 index 0000000..2ddd46c --- /dev/null +++ b/examples/abac/internal/api/rest/handler/auth_handler_test.go @@ -0,0 +1,280 @@ +package handler + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// mockUserRepository is a mock implementation of the UserRepository interface. +type mockUserRepository struct { + mock.Mock +} + +func (m *mockUserRepository) GetUserIDByEmail(ctx context.Context, email string) (uuid.UUID, error) { + args := m.Called(ctx, email) + return args.Get(0).(uuid.UUID), args.Error(1) +} + +// mockPrivateKeyFetcher is a mock implementation of the PrivateKeyFetcher interface. +type mockPrivateKeyFetcher struct { + mock.Mock +} + +func (m *mockPrivateKeyFetcher) FetchPrivateKey() (*rsa.PrivateKey, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*rsa.PrivateKey), args.Error(1) +} + +// Helper function to generate a fake RSA private key. +func generateFakeRSAPrivateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2048) +} + +// Helper function to create a test user ID. +func createTestUserID() uuid.UUID { + return uuid.New() +} + +func TestAuthHandler_SignIn(t *testing.T) { + testUserID := createTestUserID() + + cases := map[string]struct { + requestBody string + mockUserResult uuid.UUID + mockUserError error + mockKeyError error + expectedStatus int + expectedMessage string + expectedLogMessage string + expectedLogLevel slog.Level + }{ + "should Return 200 and Token on Successful Authentication": { + requestBody: `{"email": "test@example.com"}`, + mockUserResult: testUserID, + mockUserError: nil, + mockKeyError: nil, + expectedStatus: http.StatusOK, + expectedLogMessage: "Successful user sign in", + expectedLogLevel: slog.LevelInfo, + }, + + "should Return 400 on Invalid Request Body": { + requestBody: "invalid", + mockUserResult: uuid.Nil, + mockUserError: nil, + mockKeyError: nil, + expectedStatus: http.StatusBadRequest, + expectedMessage: "Invalid request format", + expectedLogMessage: "Invalid request format", + expectedLogLevel: slog.LevelWarn, + }, + + "should Return 400 on Missing Email": { + requestBody: `{"email": ""}`, + mockUserResult: uuid.Nil, + mockUserError: nil, + mockKeyError: nil, + expectedStatus: http.StatusBadRequest, + expectedMessage: "Email is required", + expectedLogMessage: "Sign in attempt with empty email", + expectedLogLevel: slog.LevelWarn, + }, + + "should Return 401 on User Not Found": { + requestBody: `{"email": "notfound@example.com"}`, + mockUserResult: uuid.Nil, + mockUserError: &repository.NotFoundError{Resource: "user", Key: "email", Value: "notfound@example.com"}, + mockKeyError: nil, + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Authentication failed", + expectedLogMessage: "Sign in attempt for non-existent user", + expectedLogLevel: slog.LevelWarn, + }, + + "should Return 401 on Database Error": { + requestBody: `{"email": "error@example.com"}`, + mockUserResult: uuid.Nil, + mockUserError: errors.New("database connection failed"), + mockKeyError: nil, + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Authentication failed", + expectedLogMessage: "Failed to retrieve user during sign in", + expectedLogLevel: slog.LevelError, + }, + + "should Return 500 on Key Fetch Failure": { + requestBody: `{"email": "test@example.com"}`, + mockUserResult: testUserID, + mockUserError: nil, + mockKeyError: errors.New("key fetch failed"), + expectedStatus: http.StatusInternalServerError, + expectedMessage: "Authentication failed", + expectedLogMessage: "Failed to generate JWT token", + expectedLogLevel: slog.LevelError, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + mockUserRepo := new(mockUserRepository) + mockKeyFetcher := new(mockPrivateKeyFetcher) + testLogger := newTestLogger() + logger := testLogger.getLogger() + + config := &AuthConfig{ + KeyFetcher: mockKeyFetcher, + Issuer: "test-issuer", + TokenTTL: time.Hour, + Audience: "test-audience", + } + + handler := NewAuthHandler(mockUserRepo, config, logger) + + if tc.mockUserResult != uuid.Nil || tc.mockUserError != nil { + mockUserRepo.On("GetUserIDByEmail", mock.Anything, mock.Anything).Return(tc.mockUserResult, tc.mockUserError) + } + + var privateKey *rsa.PrivateKey + if tc.mockKeyError == nil { + key, err := generateFakeRSAPrivateKey() + if err != nil { + t.Fatalf("failed to generate fake RSA private key: %v", err) + } + privateKey = key + } + if tc.mockUserResult != uuid.Nil && tc.mockUserError == nil { + mockKeyFetcher.On("FetchPrivateKey").Return(privateKey, tc.mockKeyError) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest( + http.MethodPost, + "/auth/signin", + bytes.NewBufferString(tc.requestBody), + ) + + handler.SignIn(w, r) + + // Assert HTTP response + assert.Equal(t, tc.expectedStatus, w.Code) + + body := w.Body.String() + if tc.expectedStatus == http.StatusOK { + assert.Contains(t, body, "token") + assert.Contains(t, body, "Bearer") + } else { + assert.Contains(t, body, tc.expectedMessage) + } + + // Assert log messages and levels + if tc.expectedLogMessage != "" { + assert.NotEmpty(t, testLogger.messages, "Expected log message but no logs were captured") + + // Check if the expected message exists in any of the captured messages + for i, message := range testLogger.messages { + assert.Contains(t, message, tc.expectedLogMessage) + assert.Equal(t, tc.expectedLogLevel, testLogger.levels[i]) + } + } + + // Verify that mocks were called as expected + mockUserRepo.AssertExpectations(t) + mockKeyFetcher.AssertExpectations(t) + testLogger.reset() + }) + } +} + +func TestAuthHandler_generateJWT(t *testing.T) { + cases := map[string]struct { + keyFetchError error + expectedError bool + validateClaims bool + }{ + "should Generate Valid JWT Token": { + keyFetchError: nil, + expectedError: false, + validateClaims: true, + }, + + "should Return Error on Key Fetch Failure": { + keyFetchError: errors.New("key fetch failed"), + expectedError: true, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + mockKeyFetcher := new(mockPrivateKeyFetcher) + testLogger := newTestLogger() + logger := testLogger.getLogger() + + config := &AuthConfig{ + KeyFetcher: mockKeyFetcher, + Issuer: "test-issuer", + TokenTTL: time.Hour, + Audience: "test-audience", + } + + handler := NewAuthHandler(nil, config, logger) + userID := createTestUserID() + + var privateKey *rsa.PrivateKey + if tc.keyFetchError == nil { + key, err := generateFakeRSAPrivateKey() + if err != nil { + t.Fatalf("failed to generate fake RSA private key: %v", err) + } + privateKey = key + } + + mockKeyFetcher.On("FetchPrivateKey").Return(privateKey, tc.keyFetchError) + + token, err := handler.generateJWT(userID) + + if tc.expectedError { + assert.Error(t, err) + assert.Empty(t, token) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, token) + } + + mockKeyFetcher.AssertExpectations(t) + }) + } +} + +func TestNewAuthHandler(t *testing.T) { + mockUserRepo := new(mockUserRepository) + testLogger := newTestLogger() + logger := testLogger.getLogger() + config := &AuthConfig{ + Issuer: "test-issuer", + TokenTTL: time.Hour, + Audience: "test-audience", + } + + handler := NewAuthHandler(mockUserRepo, config, logger) + + assert.NotNil(t, handler) + assert.Equal(t, mockUserRepo, handler.userRepo) + assert.Equal(t, config, handler.config) + assert.Equal(t, logger, handler.logger) +} diff --git a/examples/abac/internal/api/rest/handler/order_handler.go b/examples/abac/internal/api/rest/handler/order_handler.go new file mode 100644 index 0000000..07cffee --- /dev/null +++ b/examples/abac/internal/api/rest/handler/order_handler.go @@ -0,0 +1,145 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "net/http" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/api/rest/middleware" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/domain" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" +) + +const ( + OrderStatusCreated = "created" +) + +// OrderRepository defines the interface for order repository operations +type OrderRepository interface { + CreateOrder(ctx context.Context, order *domain.Order) error + GetOrderByID(ctx context.Context, id uuid.UUID) (*domain.Order, error) +} + +// OrderHandler handles HTTP requests for order operations +type OrderHandler struct { + repo OrderRepository + logger *slog.Logger +} + +// NewOrderHandler creates a new OrderHandler instance +func NewOrderHandler(repo OrderRepository, logger *slog.Logger) *OrderHandler { + return &OrderHandler{ + repo: repo, + logger: logger, + } +} + +// CreateOrderRequest represents the request payload for creating an order +type CreateOrderRequest struct { + Name string `json:"name" validate:"required"` + Attributes map[string]any `json:"attributes,omitempty"` +} + +// CreateOrderResponse represents the response for creating an order +type CreateOrderResponse struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + Attributes map[string]any `json:"attributes"` +} + +// CreateOrder handles POST /orders - creates a new order +func (h *OrderHandler) CreateOrder(w http.ResponseWriter, r *http.Request) { + var req CreateOrderRequest + + // Parse request body + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + WriteErrorResponse(w, http.StatusBadRequest, "Invalid request body", err.Error()) + return + } + + // Validate required fields + if req.Name == "" { + WriteErrorResponse(w, http.StatusBadRequest, "Invalid request", "Name is required") + return + } + + // Get user ID from context (set by JWT middleware) + userID, ok := middleware.GetUserIDFromContext(r.Context()) + if !ok { + h.logger.Error("User ID not found in context") + WriteErrorResponse(w, http.StatusUnauthorized, "Authentication required", "User authentication is required") + return + } + + // Initialize attributes if nil + if req.Attributes == nil { + req.Attributes = make(map[string]any) + } + + // Set owner (user_id) and status in attributes + req.Attributes["owner"] = userID + req.Attributes["status"] = OrderStatusCreated + + // Create order domain model + order := &domain.Order{ + ID: uuid.New(), + Name: req.Name, + Attributes: req.Attributes, + } + + // Save to database + if err := h.repo.CreateOrder(r.Context(), order); err != nil { + h.logger.Error("Failed to create order", "error", err, "order_name", order.Name, "user_id", userID) + WriteErrorResponse( + w, + http.StatusInternalServerError, + "Failed to create order", + "An internal error occurred while processing your request", + ) + return + } + + // Return success response + response := CreateOrderResponse{ + ID: order.ID, + Name: order.Name, + Attributes: order.Attributes, + } + + WriteJSONResponse(w, http.StatusCreated, response) +} + +// GetOrderByID handles GET /orders/{id} - retrieves an order by ID +func (h *OrderHandler) GetOrderByID(w http.ResponseWriter, r *http.Request) { + // Extract ID from URL path + idStr := r.PathValue("id") + + // Parse UUID + id, err := uuid.Parse(idStr) + if err != nil { + WriteErrorResponse(w, http.StatusBadRequest, "Invalid order ID", "ID must be a valid UUID") + return + } + + // Get order from database + order, err := h.repo.GetOrderByID(r.Context(), id) + if err != nil { + // Check if it's a not found error using errors.As + var notFoundErr *repository.NotFoundError + if errors.As(err, ¬FoundErr) { + h.logger.Warn("Order not found", "order_id", id, "error", err) + WriteErrorResponse(w, http.StatusNotFound, "Order not found", "The requested order could not be found") + return + } + + h.logger.Error("Failed to retrieve order", "order_id", id, "error", err) + WriteErrorResponse(w, http.StatusInternalServerError, "Failed to retrieve order", "An internal error occurred while retrieving the order") + return + } + + // Return order + WriteJSONResponse(w, http.StatusOK, order) +} diff --git a/examples/abac/internal/api/rest/handler/order_handler_test.go b/examples/abac/internal/api/rest/handler/order_handler_test.go new file mode 100644 index 0000000..d7bb0a1 --- /dev/null +++ b/examples/abac/internal/api/rest/handler/order_handler_test.go @@ -0,0 +1,483 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/api/rest/middleware" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/domain" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockOrderRepository struct { + mock.Mock +} + +func (m *mockOrderRepository) CreateOrder(ctx context.Context, order *domain.Order) error { + args := m.Called(ctx, order) + return args.Error(0) +} + +func (m *mockOrderRepository) GetOrderByID(ctx context.Context, id uuid.UUID) (*domain.Order, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Order), args.Error(1) +} + +// testLogger captures log messages and levels for testing +type testLogger struct { + messages []string + levels []slog.Level + buffer *bytes.Buffer +} + +func newTestLogger() *testLogger { + buffer := &bytes.Buffer{} + return &testLogger{ + messages: make([]string, 0), + levels: make([]slog.Level, 0), + buffer: buffer, + } +} + +func (tl *testLogger) getLogger() *slog.Logger { + handler := slog.NewTextHandler(tl.buffer, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + + // Create a custom handler that captures messages and levels + return slog.New(&captureHandler{ + testLogger: tl, + handler: handler, + }) +} + +func (tl *testLogger) reset() { + tl.messages = tl.messages[:0] + tl.levels = tl.levels[:0] + tl.buffer.Reset() +} + +// captureHandler wraps the original handler to capture log data +type captureHandler struct { + testLogger *testLogger + handler slog.Handler +} + +func (ch *captureHandler) Enabled(ctx context.Context, level slog.Level) bool { + return ch.handler.Enabled(ctx, level) +} + +func (ch *captureHandler) Handle(ctx context.Context, record slog.Record) error { //nolint:gocritic // slog.Handler interface + // Capture the message and level + ch.testLogger.messages = append(ch.testLogger.messages, record.Message) + ch.testLogger.levels = append(ch.testLogger.levels, record.Level) + + // Also call the original handler + return ch.handler.Handle(ctx, record) +} + +func (ch *captureHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &captureHandler{ + testLogger: ch.testLogger, + handler: ch.handler.WithAttrs(attrs), + } +} + +func (ch *captureHandler) WithGroup(name string) slog.Handler { + return &captureHandler{ + testLogger: ch.testLogger, + handler: ch.handler.WithGroup(name), + } +} + +type testCreateOrderInput struct { + requestBody map[string]any + userID string + hasUserInContext bool + mockCreateOrderError error +} + +type testGetOrderInput struct { + orderID string + mockOrder *domain.Order + mockGetOrderError error +} + +// Helper function to create context with user ID +func createContextWithUserID(userID string) context.Context { + return context.WithValue(context.Background(), middleware.UserIDContextKey, userID) +} + +func TestOrderHandler_CreateOrder(t *testing.T) { + testUserID := "test-user-123" + + testCases := map[string]struct { + input testCreateOrderInput + expectedStatus int + expectedError string + expectedLogMessage string + expectedLogLevel slog.Level + }{ + "should create order successfully with valid request and user context": { + input: testCreateOrderInput{ + requestBody: map[string]any{ + "name": "Test Order", + "attributes": map[string]any{ + "priority": "high", + "category": "electronics", + }, + }, + userID: testUserID, + hasUserInContext: true, + mockCreateOrderError: nil, + }, + expectedStatus: http.StatusCreated, + }, + + "should create order successfully with minimal request and set attributes": { + input: testCreateOrderInput{ + requestBody: map[string]any{ + "name": "Minimal Order", + }, + userID: testUserID, + hasUserInContext: true, + mockCreateOrderError: nil, + }, + expectedStatus: http.StatusCreated, + }, + + "should return unauthorized when user ID not in context": { + input: testCreateOrderInput{ + requestBody: map[string]any{ + "name": "Test Order", + }, + hasUserInContext: false, + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "User authentication is required", + expectedLogMessage: "User ID not found in context", + expectedLogLevel: slog.LevelError, + }, + + "should return bad request when name is missing": { + input: testCreateOrderInput{ + requestBody: map[string]any{ + "attributes": map[string]any{ + "category": "electronics", + }, + }, + userID: testUserID, + hasUserInContext: true, + }, + expectedStatus: http.StatusBadRequest, + expectedError: "Name is required", + }, + + "should return bad request when name is empty": { + input: testCreateOrderInput{ + requestBody: map[string]any{ + "name": "", + "attributes": map[string]any{}, + }, + userID: testUserID, + hasUserInContext: true, + }, + expectedStatus: http.StatusBadRequest, + expectedError: "Name is required", + }, + + "should return internal server error when repository fails": { + input: testCreateOrderInput{ + requestBody: map[string]any{ + "name": "Failed Order", + }, + userID: testUserID, + hasUserInContext: true, + mockCreateOrderError: errors.New("database connection failed"), + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "An internal error occurred while processing your request", + expectedLogMessage: "Failed to create order", + expectedLogLevel: slog.LevelError, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Initialize mock and test logger + mockRepo := &mockOrderRepository{} + testLogger := newTestLogger() + logger := testLogger.getLogger() + handler := NewOrderHandler(mockRepo, logger) + + // Setup mock behavior based on input + if tc.input.requestBody["name"] != "" && tc.input.requestBody["name"] != nil && tc.input.hasUserInContext { + mockRepo.On("CreateOrder", mock.Anything, mock.MatchedBy(func(order *domain.Order) bool { + // Verify that owner and status are set correctly + return order.Name == tc.input.requestBody["name"] && + order.Attributes["owner"] == tc.input.userID && + order.Attributes["status"] == OrderStatusCreated + })).Return(tc.input.mockCreateOrderError) + } + + // Prepare request + requestBody, _ := json.Marshal(tc.input.requestBody) + req := httptest.NewRequest(http.MethodPost, "/orders", bytes.NewBuffer(requestBody)) + req.Header.Set("Content-Type", "application/json") + + // Set user ID in context if needed + if tc.input.hasUserInContext { + ctx := createContextWithUserID(tc.input.userID) + req = req.WithContext(ctx) + } + + w := httptest.NewRecorder() + + // Execute + handler.CreateOrder(w, req) + + // Assert HTTP response + assert.Equal(t, tc.expectedStatus, w.Code) + + if tc.expectedError != "" { + var errorResponse ErrorResponse + err := json.Unmarshal(w.Body.Bytes(), &errorResponse) + assert.NoError(t, err) + assert.Contains(t, errorResponse.Message, tc.expectedError) + } else { + var response CreateOrderResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.NotEmpty(t, response.ID) + assert.Equal(t, tc.input.requestBody["name"], response.Name) + + // Verify owner and status are set + assert.Equal(t, tc.input.userID, response.Attributes["owner"]) + assert.Equal(t, OrderStatusCreated, response.Attributes["status"]) + + // Verify original attributes are preserved + if tc.input.requestBody["attributes"] != nil { + originalAttrs := tc.input.requestBody["attributes"].(map[string]any) + for key, value := range originalAttrs { + assert.Equal(t, value, response.Attributes[key]) + } + } + } + + // Assert log messages and levels + if tc.expectedLogMessage != "" { + assert.NotEmpty(t, testLogger.messages, "Expected log message but no logs were captured") + + // Check if the expected message exists in any of the captured messages + found := false + for i, message := range testLogger.messages { + if message == tc.expectedLogMessage { + assert.Equal(t, tc.expectedLogLevel, testLogger.levels[i]) + found = true + break + } + } + assert.True(t, found, "Expected log message '%s' not found in captured messages", tc.expectedLogMessage) + } + + mockRepo.AssertExpectations(t) + testLogger.reset() + }) + } +} + +func TestOrderHandler_CreateOrder_InvalidJSON(t *testing.T) { + testUserID := "test-user-123" + + testCases := map[string]struct { + requestBody string + expectedStatus int + expectedError string + }{ + "should return bad request when JSON is invalid": { + requestBody: `{"name": "test"`, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid request body", + }, + + "should return bad request when body is not JSON": { + requestBody: `invalid json`, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid request body", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Initialize mock and logger + mockRepo := &mockOrderRepository{} + testLogger := newTestLogger() + logger := testLogger.getLogger() + handler := NewOrderHandler(mockRepo, logger) + + // Prepare request + req := httptest.NewRequest(http.MethodPost, "/orders", bytes.NewBufferString(tc.requestBody)) + req.Header.Set("Content-Type", "application/json") + + // Set user ID in context + ctx := createContextWithUserID(testUserID) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + // Execute + handler.CreateOrder(w, req) + + // Assert + assert.Equal(t, tc.expectedStatus, w.Code) + + var errorResponse ErrorResponse + err := json.Unmarshal(w.Body.Bytes(), &errorResponse) + assert.NoError(t, err) + assert.Contains(t, errorResponse.Error, tc.expectedError) + + mockRepo.AssertExpectations(t) + }) + } +} + +func TestOrderHandler_GetOrderByID(t *testing.T) { + validOrderID := uuid.New() + testOrder := &domain.Order{ + ID: validOrderID, + Name: "Test Order", + Attributes: map[string]any{ + "priority": "high", + "category": "electronics", + }, + } + + testCases := map[string]struct { + input testGetOrderInput + expectedStatus int + expectedBody *domain.Order + expectedError string + expectedLogMessage string + expectedLogLevel slog.Level + }{ + "should return order successfully when order exists": { + input: testGetOrderInput{ + orderID: validOrderID.String(), + mockOrder: testOrder, + mockGetOrderError: nil, + }, + expectedStatus: http.StatusOK, + expectedBody: testOrder, + }, + + "should return not found when order does not exist": { + input: testGetOrderInput{ + orderID: validOrderID.String(), + mockOrder: nil, + mockGetOrderError: &repository.NotFoundError{ + Resource: "order", + Key: "id", + Value: validOrderID.String(), + }, + }, + expectedStatus: http.StatusNotFound, + expectedError: "The requested order could not be found", + expectedLogMessage: "Order not found", + expectedLogLevel: slog.LevelWarn, + }, + + "should return bad request when order ID is invalid": { + input: testGetOrderInput{ + orderID: "invalid-uuid", + }, + expectedStatus: http.StatusBadRequest, + expectedError: "ID must be a valid UUID", + }, + + "should return internal server error when repository fails": { + input: testGetOrderInput{ + orderID: validOrderID.String(), + mockOrder: nil, + mockGetOrderError: errors.New("database connection failed"), + }, + expectedStatus: http.StatusInternalServerError, + expectedError: "An internal error occurred while retrieving the order", + expectedLogMessage: "Failed to retrieve order", + expectedLogLevel: slog.LevelError, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Initialize mock and test logger + mockRepo := &mockOrderRepository{} + testLogger := newTestLogger() + logger := testLogger.getLogger() + handler := NewOrderHandler(mockRepo, logger) + + // Setup mock behavior based on input + if tc.input.orderID != "invalid-uuid" { + orderID, _ := uuid.Parse(tc.input.orderID) + mockRepo.On("GetOrderByID", mock.Anything, orderID).Return(tc.input.mockOrder, tc.input.mockGetOrderError) + } + + // Prepare request with mux router + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/orders/%s", tc.input.orderID), http.NoBody) + w := httptest.NewRecorder() + + // Setup mux router to extract path variables + router := http.NewServeMux() + router.HandleFunc("GET /orders/{id}", handler.GetOrderByID) + router.ServeHTTP(w, req) + + // Assert HTTP response + assert.Equal(t, tc.expectedStatus, w.Code) + + if tc.expectedError != "" { + var errorResponse ErrorResponse + err := json.Unmarshal(w.Body.Bytes(), &errorResponse) + assert.NoError(t, err) + assert.Contains(t, errorResponse.Message, tc.expectedError) + } else { + var response domain.Order + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, tc.expectedBody.ID, response.ID) + assert.Equal(t, tc.expectedBody.Name, response.Name) + assert.Equal(t, tc.expectedBody.Attributes, response.Attributes) + } + + // Assert log messages and levels + if tc.expectedLogMessage != "" { + assert.NotEmpty(t, testLogger.messages, "Expected log message but no logs were captured") + + // Check if the expected message exists in any of the captured messages + found := false + for i, message := range testLogger.messages { + if message == tc.expectedLogMessage { + assert.Equal(t, tc.expectedLogLevel, testLogger.levels[i]) + found = true + break + } + } + assert.True(t, found, "Expected log message '%s' not found in captured messages", tc.expectedLogMessage) + } + + mockRepo.AssertExpectations(t) + testLogger.reset() + }) + } +} diff --git a/examples/abac/internal/api/rest/handler/response.go b/examples/abac/internal/api/rest/handler/response.go new file mode 100644 index 0000000..187848c --- /dev/null +++ b/examples/abac/internal/api/rest/handler/response.go @@ -0,0 +1,28 @@ +package handler + +import ( + "encoding/json" + "net/http" +) + +// ErrorResponse represents an error response +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message,omitempty"` +} + +// WriteJSONResponse writes a JSON response with the given status code and data +func WriteJSONResponse(w http.ResponseWriter, statusCode int, data any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(data) +} + +// WriteErrorResponse writes an error response with the given status code and message +func WriteErrorResponse(w http.ResponseWriter, statusCode int, err, message string) { + response := ErrorResponse{ + Error: err, + Message: message, + } + WriteJSONResponse(w, statusCode, response) +} diff --git a/examples/abac/internal/api/rest/middleware/jwt_auth.go b/examples/abac/internal/api/rest/middleware/jwt_auth.go new file mode 100644 index 0000000..a0c2b0f --- /dev/null +++ b/examples/abac/internal/api/rest/middleware/jwt_auth.go @@ -0,0 +1,183 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strings" + "time" + + "github.com/CameronXie/access-control-explorer/examples/abac/pkg/keyfetcher" + "github.com/golang-jwt/jwt/v5" +) + +type contextKey string + +const ( + BearerPrefix = "bearer" + DefaultClockSkewTolerance = 5 * time.Minute + UserIDContextKey contextKey = "user_id" +) + +// JWTAuthMiddleware handles JWT authentication and sets user ID in context +type JWTAuthMiddleware struct { + keyFetcher keyfetcher.PublicKeyFetcher + issuer string + audience string + clockSkew time.Duration +} + +// JWTConfig holds configuration for JWT authentication middleware +type JWTConfig struct { + KeyFetcher keyfetcher.PublicKeyFetcher + Issuer string + Audience string + ClockSkew time.Duration // Optional: defaults to DefaultClockSkewTolerance +} + +// NewJWTAuthMiddleware creates a new JWT authentication middleware +func NewJWTAuthMiddleware(config JWTConfig) *JWTAuthMiddleware { + clockSkew := config.ClockSkew + if clockSkew == 0 { + clockSkew = DefaultClockSkewTolerance + } + + return &JWTAuthMiddleware{ + keyFetcher: config.KeyFetcher, + issuer: config.Issuer, + audience: config.Audience, + clockSkew: clockSkew, + } +} + +// Handler returns an HTTP middleware function that validates JWT tokens +func (m *JWTAuthMiddleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userID, err := m.validateJWTAndExtractUserID(r) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Set user ID in context + ctx := context.WithValue(r.Context(), UserIDContextKey, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// validateJWTAndExtractUserID validates JWT token and returns user ID (subject) +func (m *JWTAuthMiddleware) validateJWTAndExtractUserID(r *http.Request) (string, error) { + token, err := m.parseToken(r) + if err != nil { + return "", err + } + + claims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok || !token.Valid { + return "", errors.New("invalid token") + } + + userID, err := m.validateClaims(claims) + if err != nil { + return "", fmt.Errorf("invalid claims: %w", err) + } + + return userID, nil +} + +// parseToken extracts and parses JWT token from request +func (m *JWTAuthMiddleware) parseToken(r *http.Request) (*jwt.Token, error) { + tokenString, err := extractBearerToken(r) + if err != nil { + return nil, err + } + + key, err := m.keyFetcher.FetchPublicKey() + if err != nil { + return nil, fmt.Errorf("failed to fetch public key: %w", err) + } + + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) { + // Ensure token uses RSA signing method + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return key, nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + return token, nil +} + +// validateClaims validates JWT claims and returns subject ID +func (m *JWTAuthMiddleware) validateClaims(claims *jwt.RegisteredClaims) (string, error) { + if err := m.validateRequiredClaims(claims); err != nil { + return "", err + } + + if err := m.validateTiming(claims); err != nil { + return "", err + } + + return claims.Subject, nil +} + +// validateRequiredClaims validates issuer, audience, and subject claims +func (m *JWTAuthMiddleware) validateRequiredClaims(claims *jwt.RegisteredClaims) error { + if claims.Subject == "" { + return errors.New("missing subject claim") + } + + if claims.Issuer != m.issuer { + return fmt.Errorf("invalid issuer: got %s, want %s", claims.Issuer, m.issuer) + } + + if !slices.Contains(claims.Audience, m.audience) { + return fmt.Errorf("invalid audience: missing %s", m.audience) + } + + return nil +} + +// validateTiming validates expiration and issued-at claims with clock skew tolerance +func (m *JWTAuthMiddleware) validateTiming(claims *jwt.RegisteredClaims) error { + now := time.Now() + + // Check expiration (required) + if claims.ExpiresAt == nil { + return errors.New("missing expiration claim") + } + + // Check issued-at time with clock skew tolerance (optional claim) + if claims.IssuedAt != nil && claims.IssuedAt.After(now.Add(m.clockSkew)) { + return errors.New("token issued too far in future") + } + + return nil +} + +// extractBearerToken extracts JWT token from Authorization header +func extractBearerToken(r *http.Request) (string, error) { + auth := r.Header.Get("Authorization") + if auth == "" { + return "", errors.New("missing authorization header") + } + + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], BearerPrefix) { + return "", errors.New("invalid authorization format") + } + + return parts[1], nil +} + +// GetUserIDFromContext extracts user ID from request context +func GetUserIDFromContext(ctx context.Context) (string, bool) { + userID, ok := ctx.Value(UserIDContextKey).(string) + return userID, ok +} diff --git a/examples/abac/internal/api/rest/middleware/jwt_auth_test.go b/examples/abac/internal/api/rest/middleware/jwt_auth_test.go new file mode 100644 index 0000000..9bfd931 --- /dev/null +++ b/examples/abac/internal/api/rest/middleware/jwt_auth_test.go @@ -0,0 +1,586 @@ +package middleware + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// mockKeyFetcher is a mock implementation of keyfetcher.PublicKeyFetcher +type mockKeyFetcher struct { + mock.Mock +} + +func (m *mockKeyFetcher) FetchPublicKey() (*rsa.PublicKey, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*rsa.PublicKey), args.Error(1) +} + +// Test helper functions +func generateTestKeyPair(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return privateKey, &privateKey.PublicKey +} + +func createTestToken(t *testing.T, privateKey *rsa.PrivateKey, claims *jwt.RegisteredClaims) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + return tokenString +} + +func createValidClaims(issuer, audience, subject string) *jwt.RegisteredClaims { + return &jwt.RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + } +} + +func TestNewJWTAuthMiddleware(t *testing.T) { + testCases := map[string]struct { + config JWTConfig + want time.Duration + }{ + "should use custom clock skew when provided": { + config: JWTConfig{ + KeyFetcher: &mockKeyFetcher{}, + Issuer: "test-issuer", + Audience: "test-audience", + ClockSkew: 10 * time.Minute, + }, + want: 10 * time.Minute, + }, + "should use default clock skew when not provided": { + config: JWTConfig{ + KeyFetcher: &mockKeyFetcher{}, + Issuer: "test-issuer", + Audience: "test-audience", + }, + want: DefaultClockSkewTolerance, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + middleware := NewJWTAuthMiddleware(tc.config) + assert.Equal(t, tc.want, middleware.clockSkew) + assert.Equal(t, tc.config.Issuer, middleware.issuer) + assert.Equal(t, tc.config.Audience, middleware.audience) + }) + } +} + +func TestJWTAuthMiddleware_Handler(t *testing.T) { + privateKey, publicKey := generateTestKeyPair(t) + now := time.Now() + + testCases := map[string]struct { + setupRequest func() *http.Request + setupMock func(*mockKeyFetcher) + expectedStatus int + expectedUserID string + }{ + "should authenticate successfully with valid token": { + setupRequest: func() *http.Request { + claims := createValidClaims("test-issuer", "test-audience", "userId") + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusOK, + expectedUserID: "userId", + }, + "should return unauthorized when authorization header is missing": { + setupRequest: func() *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody) + }, + setupMock: func(_ *mockKeyFetcher) { + // No mock setup needed + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when authorization format is invalid": { + setupRequest: func() *http.Request { + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "InvalidFormat token") + return req + }, + setupMock: func(_ *mockKeyFetcher) { + // No mock setup needed + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when key fetcher fails": { + setupRequest: func() *http.Request { + claims := createValidClaims("test-issuer", "test-audience", "user123") + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(nil, errors.New("key fetch error")) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when token format is invalid": { + setupRequest: func() *http.Request { + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer invalid.token.format") + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when token uses wrong signing method": { + setupRequest: func() *http.Request { + // Create token with HMAC instead of RSA + claims := createValidClaims("test-issuer", "test-audience", "user123") + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte("secret")) + require.NoError(t, err) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+tokenString) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when token is expired": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "user123", + Audience: []string{"test-audience"}, + ExpiresAt: jwt.NewNumericDate(now.Add(-time.Hour)), // Expired + IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Hour)), + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when issuer is invalid": { + setupRequest: func() *http.Request { + claims := createValidClaims("wrong-issuer", "test-audience", "user123") + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when audience is invalid": { + setupRequest: func() *http.Request { + claims := createValidClaims("test-issuer", "wrong-audience", "user123") + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when subject is missing": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "", // Empty subject + Audience: []string{"test-audience"}, + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(now), + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when expiration is missing": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "user123", + Audience: []string{"test-audience"}, + // ExpiresAt: nil, // Missing expiration + IssuedAt: jwt.NewNumericDate(now), + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should return unauthorized when token is issued too far in future": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "user123", + Audience: []string{"test-audience"}, + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), // Beyond tolerance + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusUnauthorized, + }, + "should accept token issued in future within tolerance": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "user123", + Audience: []string{"test-audience"}, + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(now.Add(1 * time.Minute)), // Within tolerance + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusOK, + expectedUserID: "user123", + }, + "should accept token with multiple audiences containing valid one": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "user123", + Audience: []string{"other-audience", "test-audience", "third-audience"}, + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(now), + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusOK, + expectedUserID: "user123", + }, + "should accept token without issued at claim": { + setupRequest: func() *http.Request { + claims := &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Subject: "user123", + Audience: []string{"test-audience"}, + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + // IssuedAt: nil, // No issued at claim + } + token := createTestToken(t, privateKey, claims) + req := httptest.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + return req + }, + setupMock: func(m *mockKeyFetcher) { + m.On("FetchPublicKey").Return(publicKey, nil) + }, + expectedStatus: http.StatusOK, + expectedUserID: "user123", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + mockKeyFetcher := &mockKeyFetcher{} + tc.setupMock(mockKeyFetcher) + + middleware := NewJWTAuthMiddleware(JWTConfig{ + KeyFetcher: mockKeyFetcher, + Issuer: "test-issuer", + Audience: "test-audience", + }) + + // Create a test handler that captures the user ID from context + var capturedUserID string + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if userID, ok := GetUserIDFromContext(r.Context()); ok { + capturedUserID = userID + } + w.WriteHeader(http.StatusOK) + }) + + handler := middleware.Handler(nextHandler) + req := tc.setupRequest() + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + if tc.expectedStatus == http.StatusOK { + assert.Equal(t, tc.expectedUserID, capturedUserID) + } + + mockKeyFetcher.AssertExpectations(t) + }) + } +} + +func TestExtractBearerToken(t *testing.T) { + testCases := map[string]struct { + authorization string + expectedToken string + expectedError bool + }{ + "should extract token from valid bearer header": { + authorization: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + expectedToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + }, + "should extract token from lowercase bearer header": { + authorization: "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + expectedToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + }, + "should extract token from mixed case bearer header": { + authorization: "BeArEr eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + expectedToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + }, + "should return error when authorization header is missing": { + authorization: "", + expectedError: true, + }, + "should return error when bearer token has no space": { + authorization: "Bearertoken", + expectedError: true, + }, + "should return error when authorization uses wrong scheme": { + authorization: "Basic dXNlcjpwYXNz", + expectedError: true, + }, + "should return error when only bearer is provided": { + authorization: "Bearer", + expectedError: true, + }, + "should handle token with spaces in it": { + authorization: "Bearer token extra", + expectedToken: "token extra", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", http.NoBody) + if tc.authorization != "" { + req.Header.Set("Authorization", tc.authorization) + } + + token, err := extractBearerToken(req) + + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedToken, token) + } + }) + } +} + +func TestGetUserIDFromContext(t *testing.T) { + testCases := map[string]struct { + setupCtx func() context.Context + expectedID string + expectedOK bool + }{ + "should extract user ID from context successfully": { + setupCtx: func() context.Context { + return context.WithValue(context.Background(), UserIDContextKey, "user123") + }, + expectedID: "user123", + expectedOK: true, + }, + "should return false when user ID is missing from context": { + setupCtx: func() context.Context { + return context.Background() + }, + expectedID: "", + expectedOK: false, + }, + "should return false when context value has wrong type": { + setupCtx: func() context.Context { + return context.WithValue(context.Background(), UserIDContextKey, 123) + }, + expectedID: "", + expectedOK: false, + }, + "should return false when context has different key": { + setupCtx: func() context.Context { + return context.WithValue(context.Background(), contextKey("different_key"), "user123") + }, + expectedID: "", + expectedOK: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + ctx := tc.setupCtx() + userID, ok := GetUserIDFromContext(ctx) + assert.Equal(t, tc.expectedID, userID) + assert.Equal(t, tc.expectedOK, ok) + }) + } +} + +func TestJWTAuthMiddleware_validateRequiredClaims(t *testing.T) { + middleware := NewJWTAuthMiddleware(JWTConfig{ + KeyFetcher: &mockKeyFetcher{}, + Issuer: "test-issuer", + Audience: "test-audience", + }) + + testCases := map[string]struct { + claims *jwt.RegisteredClaims + expectedError string + }{ + "should validate claims successfully": { + claims: &jwt.RegisteredClaims{ + Subject: "user123", + Issuer: "test-issuer", + Audience: []string{"test-audience"}, + }, + }, + "should return error when subject is missing": { + claims: &jwt.RegisteredClaims{ + Issuer: "test-issuer", + Audience: []string{"test-audience"}, + }, + expectedError: "missing subject claim", + }, + "should return error when issuer is wrong": { + claims: &jwt.RegisteredClaims{ + Subject: "user123", + Issuer: "wrong-issuer", + Audience: []string{"test-audience"}, + }, + expectedError: "invalid issuer", + }, + "should return error when audience is wrong": { + claims: &jwt.RegisteredClaims{ + Subject: "user123", + Issuer: "test-issuer", + Audience: []string{"wrong-audience"}, + }, + expectedError: "invalid audience", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + err := middleware.validateRequiredClaims(tc.claims) + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestJWTAuthMiddleware_validateTiming(t *testing.T) { + middleware := NewJWTAuthMiddleware(JWTConfig{ + KeyFetcher: &mockKeyFetcher{}, + Issuer: "test-issuer", + Audience: "test-audience", + ClockSkew: 5 * time.Minute, + }) + + now := time.Now() + + testCases := map[string]struct { + claims *jwt.RegisteredClaims + expectedError string + }{ + "should validate timing successfully": { + claims: &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(now.Add(-time.Minute)), + }, + }, + "should return error when expiration is missing": { + claims: &jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + }, + expectedError: "missing expiration claim", + }, + "should return error when token is issued too far in future": { + claims: &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + }, + expectedError: "token issued too far in future", + }, + "should accept token without issued at claim": { + claims: &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + err := middleware.validateTiming(tc.claims) + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/examples/abac/internal/domain/order.go b/examples/abac/internal/domain/order.go new file mode 100644 index 0000000..fc20d89 --- /dev/null +++ b/examples/abac/internal/domain/order.go @@ -0,0 +1,10 @@ +package domain + +import "github.com/google/uuid" + +// Order represents an order entity with ID, name, and flexible attributes +type Order struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + Attributes map[string]any `json:"attributes"` +} diff --git a/examples/abac/internal/enforcer/enforcer.go b/examples/abac/internal/enforcer/enforcer.go new file mode 100644 index 0000000..31d3463 --- /dev/null +++ b/examples/abac/internal/enforcer/enforcer.go @@ -0,0 +1,261 @@ +package enforcer + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" +) + +// AdviceHandler defines the interface for handling advice +type AdviceHandler interface { + Handle(ctx context.Context, advice ro.Advice, w http.ResponseWriter, r *http.Request) error +} + +// ObligationHandler defines the interface for handling obligations +type ObligationHandler interface { + Handle(ctx context.Context, obligation ro.Obligation, w http.ResponseWriter, r *http.Request) error +} + +// RequestExtractor defines the interface for extracting access request from HTTP request +type RequestExtractor interface { + Extract(ctx context.Context, r *http.Request) (*ro.AccessRequest, error) +} + +// ErrorResponse represents a standardized error response +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` +} + +// Enforcer represents the Policy Enforcement Point middleware +// Note: PEP logging here is operational (observability and correlation). +// It is intentionally distinct from any auditing performed via obligations. +type Enforcer struct { + orchestrator ro.RequestOrchestrator + requestExtractor RequestExtractor + adviceHandlers map[string]AdviceHandler + obligationHandlers map[string]ObligationHandler + errorHandler func(w http.ResponseWriter, r *http.Request, statusCode int, errorResp ErrorResponse) + logger *slog.Logger +} + +// Option defines configuration options for Enforcer +type Option func(*Enforcer) + +// NewEnforcer creates a new Enforcer instance with the given request orchestrator and options +func NewEnforcer(orchestrator ro.RequestOrchestrator, extractor RequestExtractor, logger *slog.Logger, options ...Option) *Enforcer { + enforcer := &Enforcer{ + orchestrator: orchestrator, + requestExtractor: extractor, + adviceHandlers: make(map[string]AdviceHandler), + obligationHandlers: make(map[string]ObligationHandler), + errorHandler: defaultErrorHandler, + logger: logger, + } + + for _, option := range options { + option(enforcer) + } + + return enforcer +} + +// WithAdviceHandler registers an advice handler for a specific advice ID +func WithAdviceHandler(adviceID string, handler AdviceHandler) Option { + return func(e *Enforcer) { + e.adviceHandlers[adviceID] = handler + } +} + +// WithObligationHandler registers an obligation handler for a specific obligation ID +func WithObligationHandler(obligationID string, handler ObligationHandler) Option { + return func(e *Enforcer) { + e.obligationHandlers[obligationID] = handler + } +} + +// WithErrorHandler sets a custom error response handler for consistent error formatting +func WithErrorHandler(handler func(w http.ResponseWriter, r *http.Request, statusCode int, errorResp ErrorResponse)) Option { + return func(e *Enforcer) { + e.errorHandler = handler + } +} + +// Enforce returns an HTTP middleware that enforces access control. +func (e *Enforcer) Enforce(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + ctx := r.Context() + reqLogger := e.logger.With( + "method", r.Method, + "path", r.URL.Path, + "remote_addr", r.RemoteAddr, + ) + + // Extract access request from HTTP request + accessReq, err := e.requestExtractor.Extract(ctx, r) + if err != nil { + reqLogger.ErrorContext(ctx, "request_extraction_failed", + slog.String("error", err.Error()), + ) + e.errorHandler(w, r, http.StatusBadRequest, ErrorResponse{ + Error: "request_extraction_failed", + Message: "Invalid access request", + }) + return + } + + // Evaluate access using the request orchestrator + accessResp, err := e.orchestrator.EvaluateAccess(ctx, accessReq) + if err != nil { + reqLogger.ErrorContext(ctx, "access_evaluation_failed", + slog.String("error", err.Error()), + slog.Duration("duration_ms", time.Since(start)), + ) + e.errorHandler(w, r, http.StatusInternalServerError, ErrorResponse{ + Error: "access_evaluation_failed", + Message: "An internal error occurred while evaluating access", + }) + return + } + + // Update the request-scoped logger with request ID for correlation + reqLogger = reqLogger.With("access_request_id", accessResp.RequestID.String()) + + // Handle decision + switch accessResp.Decision { + case ro.Permit: + // Handle obligations before allowing access + if err := e.handleObligations(ctx, accessResp.Obligations, w, r); err != nil { + reqLogger.ErrorContext(ctx, "obligation_failed", + slog.String("error", err.Error()), + slog.Int("obligations_count", len(accessResp.Obligations)), + slog.Int("advices_count", len(accessResp.Advices)), + slog.String("decision", string(ro.Permit)), + slog.Duration("duration_ms", time.Since(start)), + ) + e.errorHandler(w, r, http.StatusInternalServerError, ErrorResponse{ + Error: "obligation_failed", + Message: "An internal error occurred while enforcing obligations", + }) + return + } + + // Handle advice (non-blocking) + if err := e.handleAdvice(ctx, accessResp.Advices, w, r); err != nil { + reqLogger.WarnContext(ctx, "advice_failed", + slog.String("error", err.Error()), + slog.Int("advices_count", len(accessResp.Advices)), + slog.String("decision", string(ro.Permit)), + ) + } + + reqLogger.InfoContext(ctx, "access_permitted", + slog.Int("obligations_count", len(accessResp.Obligations)), + slog.Int("advices_count", len(accessResp.Advices)), + slog.String("decision", string(ro.Permit)), + slog.Duration("duration_ms", time.Since(start)), + ) + + // Allow access to the protected resource + next.ServeHTTP(w, r) + + case ro.Deny: + if err := e.handleObligations(ctx, accessResp.Obligations, w, r); err != nil { + reqLogger.WarnContext(ctx, "obligation_failed_on_deny", + slog.String("error", err.Error()), + slog.Int("obligations_count", len(accessResp.Obligations)), + slog.String("decision", string(ro.Deny)), + ) + } + if err := e.handleAdvice(ctx, accessResp.Advices, w, r); err != nil { + reqLogger.WarnContext(ctx, "advice_failed_on_deny", + slog.String("error", err.Error()), + slog.Int("advices_count", len(accessResp.Advices)), + slog.String("decision", string(ro.Deny)), + ) + } + + reqLogger.InfoContext(ctx, "access_denied", + slog.Int("obligations_count", len(accessResp.Obligations)), + slog.Int("advices_count", len(accessResp.Advices)), + slog.String("decision", string(ro.Deny)), + slog.Duration("duration_ms", time.Since(start)), + ) + + e.errorHandler(w, r, http.StatusForbidden, ErrorResponse{ + Error: "access_denied", + Message: "You do not have permission to access this resource", + }) + + case ro.NotApplicable: + reqLogger.InfoContext(ctx, "access_not_applicable", + slog.String("decision", string(ro.NotApplicable)), + slog.Int("obligations_count", len(accessResp.Obligations)), + slog.Int("advices_count", len(accessResp.Advices)), + slog.Duration("duration_ms", time.Since(start)), + ) + + e.errorHandler(w, r, http.StatusForbidden, ErrorResponse{ + Error: "access_denied", + Message: "You do not have permission to access this resource", + }) + + case ro.Indeterminate: + reqLogger.ErrorContext(ctx, "access_indeterminate", + slog.String("decision", string(ro.Indeterminate)), + slog.String("status_code", string(accessResp.Status.Code)), + slog.String("status_message", accessResp.Status.Message), + slog.Duration("duration_ms", time.Since(start)), + ) + + e.errorHandler(w, r, http.StatusInternalServerError, ErrorResponse{ + Error: "indeterminate_decision", + Message: "An internal error occurred while processing the access decision", + }) + } + }) +} + +// handleObligations processes all obligations that must be fulfilled +func (e *Enforcer) handleObligations(ctx context.Context, obligations []ro.Obligation, w http.ResponseWriter, r *http.Request) error { + for _, obligation := range obligations { + handler, exists := e.obligationHandlers[obligation.ID] + if !exists { + return fmt.Errorf("no handler registered for obligation ID: %s", obligation.ID) + } + + if err := handler.Handle(ctx, obligation, w, r); err != nil { + return fmt.Errorf("obligation handler failed for ID %s: %w", obligation.ID, err) + } + } + return nil +} + +// handleAdvice processes all advice (non-blocking suggestions) +func (e *Enforcer) handleAdvice(ctx context.Context, advices []ro.Advice, w http.ResponseWriter, r *http.Request) error { + for _, advice := range advices { + handler, exists := e.adviceHandlers[advice.ID] + if !exists { + // Advice is optional, so missing handlers are not errors + continue + } + + if err := handler.Handle(ctx, advice, w, r); err != nil { + return fmt.Errorf("advice handler failed for ID %s: %w", advice.ID, err) + } + } + return nil +} + +// defaultErrorHandler provides a consistent error response format +func defaultErrorHandler(w http.ResponseWriter, _ *http.Request, statusCode int, errorResp ErrorResponse) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(errorResp) +} diff --git a/examples/abac/internal/enforcer/enforcer_test.go b/examples/abac/internal/enforcer/enforcer_test.go new file mode 100644 index 0000000..510ea93 --- /dev/null +++ b/examples/abac/internal/enforcer/enforcer_test.go @@ -0,0 +1,615 @@ +//nolint:lll // unit test +package enforcer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" +) + +// Mock implementations +type mockRequestOrchestrator struct { + mock.Mock +} + +func (m *mockRequestOrchestrator) EvaluateAccess(ctx context.Context, req *ro.AccessRequest) (*ro.AccessResponse, error) { + args := m.Called(ctx, req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ro.AccessResponse), args.Error(1) +} + +type mockRequestExtractor struct { + mock.Mock +} + +func (m *mockRequestExtractor) Extract(ctx context.Context, r *http.Request) (*ro.AccessRequest, error) { + args := m.Called(ctx, r) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ro.AccessRequest), args.Error(1) +} + +type mockObligationHandler struct { + mock.Mock +} + +func (m *mockObligationHandler) Handle(ctx context.Context, obligation ro.Obligation, w http.ResponseWriter, r *http.Request) error { + args := m.Called(ctx, obligation, w, r) + return args.Error(0) +} + +type mockAdviceHandler struct { + mock.Mock +} + +func (m *mockAdviceHandler) Handle(ctx context.Context, advice ro.Advice, w http.ResponseWriter, r *http.Request) error { + args := m.Called(ctx, advice, w, r) + return args.Error(0) +} + +// Test logger that captures log messages and levels +type testLogHandler struct { + messages []string + levels []slog.Level +} + +func (*testLogHandler) Enabled(context.Context, slog.Level) bool { + return true +} + +func (h *testLogHandler) Handle(_ context.Context, r slog.Record) error { //nolint:gocritic // slog.Handler interface + h.messages = append(h.messages, r.Message) + h.levels = append(h.levels, r.Level) + return nil +} + +func (h *testLogHandler) WithAttrs(_ []slog.Attr) slog.Handler { + return h +} + +func (h *testLogHandler) WithGroup(_ string) slog.Handler { + return h +} + +type expectedLog struct { + level slog.Level + contains string +} + +func TestEnforcer_Enforce(t *testing.T) { //nolint:gocyclo // unit test + // Common helpers to reduce duplication across similar cases + baseDeleteDocRequest := &ro.AccessRequest{ + Subject: ro.Subject{ID: "user123", Type: "user"}, + Resource: ro.Resource{ID: "doc456", Type: "document"}, + Action: ro.Action{ID: "delete"}, + } + baseReadDocRequest := &ro.AccessRequest{ + Subject: ro.Subject{ID: "user123", Type: "user"}, + Resource: ro.Resource{ID: "doc456", Type: "document"}, + Action: ro.Action{ID: "read"}, + } + newBaseDenyResponse := func() *ro.AccessResponse { + return &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Deny, + Status: ro.Status{Code: "denied", Message: "Insufficient permissions"}, + EvaluatedAt: time.Now(), + } + } + newBaseNotApplicableResponse := func() *ro.AccessResponse { + return &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.NotApplicable, + Status: ro.Status{Code: "PolicyNotFound", Message: "no applicable policy was found for this request"}, + EvaluatedAt: time.Now(), + } + } + + testCases := map[string]struct { + // Request setup + httpMethod string + httpPath string + + // Mock behaviors + extractorResult *ro.AccessRequest + extractorError error + orchestratorResult *ro.AccessResponse + orchestratorError error + obligationHandlers map[string]error // obligationID -> error (nil means success) + adviceHandlers map[string]error // adviceID -> error (nil means success) + + // Expected results + expectedStatus int + expectedErrorResp *ErrorResponse // For JSON error responses + nextCalled bool + expectedLogs []expectedLog + }{ + "should allow access when decision is permit and obligations succeed": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Permit, + Status: ro.Status{Code: "ok", Message: "Permitted"}, + EvaluatedAt: time.Now(), + Obligations: []ro.Obligation{ + {ID: "audit", Attributes: map[string]any{"level": "INFO", "message": "ok"}}, + }, + Advices: []ro.Advice{ + {ID: "analytics", Attributes: map[string]any{"track": true}}, + }, + }, + obligationHandlers: map[string]error{ + "audit": nil, + }, + adviceHandlers: map[string]error{ + "analytics": nil, + }, + expectedStatus: 200, + nextCalled: true, + expectedLogs: []expectedLog{ + {level: slog.LevelInfo, contains: "access_permitted"}, + }, + }, + + "should deny access when decision is deny": { + httpMethod: "DELETE", + httpPath: "/api/documents/123", + extractorResult: baseDeleteDocRequest, + orchestratorResult: newBaseDenyResponse(), + expectedStatus: http.StatusForbidden, + expectedErrorResp: &ErrorResponse{ + Error: "access_denied", + Message: "You do not have permission to access this resource", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelInfo, contains: "access_denied"}, + }, + }, + + "should log warn when obligation handler fails on deny": { + httpMethod: "DELETE", + httpPath: "/api/documents/123", + extractorResult: baseDeleteDocRequest, + orchestratorResult: func() *ro.AccessResponse { + resp := newBaseDenyResponse() + resp.Obligations = []ro.Obligation{ + {ID: "deny-audit", Attributes: map[string]any{"level": "WARN", "message": "deny audit"}}, + } + return resp + }(), + obligationHandlers: map[string]error{ + "deny-audit": errors.New("audit sink unavailable"), + }, + expectedStatus: http.StatusForbidden, + expectedErrorResp: &ErrorResponse{ + Error: "access_denied", + Message: "You do not have permission to access this resource", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelWarn, contains: "obligation_failed_on_deny"}, + {level: slog.LevelInfo, contains: "access_denied"}, + }, + }, + + "should log warn when advice handler fails on deny": { + httpMethod: "DELETE", + httpPath: "/api/documents/123", + extractorResult: baseDeleteDocRequest, + orchestratorResult: func() *ro.AccessResponse { + resp := newBaseDenyResponse() + resp.Advices = []ro.Advice{ + {ID: "notify-admin", Attributes: map[string]any{"reason": "unauthorized_access"}}, + } + return resp + }(), + adviceHandlers: map[string]error{ + "notify-admin": errors.New("notification service unavailable"), + }, + expectedStatus: http.StatusForbidden, + expectedErrorResp: &ErrorResponse{ + Error: "access_denied", + Message: "You do not have permission to access this resource", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelWarn, contains: "advice_failed_on_deny"}, + {level: slog.LevelInfo, contains: "access_denied"}, + }, + }, + + "should return 400 when request extraction fails": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorError: errors.New("missing required headers"), + expectedStatus: 400, + expectedErrorResp: &ErrorResponse{ + Error: "request_extraction_failed", + Message: "Invalid access request", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelError, contains: "request_extraction_failed"}, + }, + }, + + "should return 500 when access evaluation fails": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorError: errors.New("PDP service unavailable"), + expectedStatus: 500, + expectedErrorResp: &ErrorResponse{ + Error: "access_evaluation_failed", + Message: "An internal error occurred while evaluating access", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelError, contains: "access_evaluation_failed"}, + }, + }, + + "should return 500 when obligation handler is not registered": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Permit, + Status: ro.Status{Code: "ok", Message: "Permitted"}, + EvaluatedAt: time.Now(), + Obligations: []ro.Obligation{ + {ID: "missing-handler", Attributes: map[string]any{"level": "INFO", "message": "x"}}, + }, + }, + obligationHandlers: map[string]error{}, // No handler registered + expectedStatus: 500, + expectedErrorResp: &ErrorResponse{ + Error: "obligation_failed", + Message: "An internal error occurred while enforcing obligations", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelError, contains: "obligation_failed"}, + }, + }, + + "should return 500 when obligation handler fails": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Permit, + Status: ro.Status{Code: "ok", Message: "Permitted"}, + EvaluatedAt: time.Now(), + Obligations: []ro.Obligation{ + {ID: "audit", Attributes: map[string]any{"level": "INFO", "message": "x"}}, + }, + }, + obligationHandlers: map[string]error{ + "audit": errors.New("audit service unavailable"), + }, + expectedStatus: 500, + expectedErrorResp: &ErrorResponse{ + Error: "obligation_failed", + Message: "An internal error occurred while enforcing obligations", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelError, contains: "obligation_failed"}, + }, + }, + + "should return 500 when decision is indeterminate": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Indeterminate, + Status: ro.Status{Code: "syntax_error", Message: "Policy syntax error"}, + EvaluatedAt: time.Now(), + }, + expectedStatus: 500, + expectedErrorResp: &ErrorResponse{ + Error: "indeterminate_decision", + Message: "An internal error occurred while processing the access decision", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelError, contains: "access_indeterminate"}, + }, + }, + + "should warn when advice handler fails but continue processing": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Permit, + Status: ro.Status{Code: "ok", Message: "Permitted"}, + EvaluatedAt: time.Now(), + Advices: []ro.Advice{ + {ID: "analytics", Attributes: map[string]any{"track": true}}, + }, + }, + adviceHandlers: map[string]error{ + "analytics": errors.New("analytics service timeout"), + }, + expectedStatus: 200, + nextCalled: true, + expectedLogs: []expectedLog{ + {level: slog.LevelWarn, contains: "advice_failed"}, + {level: slog.LevelInfo, contains: "access_permitted"}, + }, + }, + + "should skip advice when no handler is registered": { + httpMethod: "GET", + httpPath: "/api/documents/123", + extractorResult: baseReadDocRequest, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Permit, + Status: ro.Status{Code: "ok", Message: "Permitted"}, + EvaluatedAt: time.Now(), + Advices: []ro.Advice{ + {ID: "unhandled-advice", Attributes: map[string]any{"track": true}}, + }, + }, + adviceHandlers: map[string]error{}, // No handler registered - should be silently skipped + expectedStatus: 200, + nextCalled: true, + expectedLogs: []expectedLog{ + {level: slog.LevelInfo, contains: "access_permitted"}, + }, + }, + + "should handle multiple obligations and advices": { + httpMethod: "POST", + httpPath: "/api/documents", + extractorResult: &ro.AccessRequest{ + Subject: ro.Subject{ID: "user123", Type: "user"}, + Resource: ro.Resource{ID: "documents", Type: "collection"}, + Action: ro.Action{ID: "create"}, + }, + orchestratorResult: &ro.AccessResponse{ + RequestID: uuid.New(), + Decision: ro.Permit, + Status: ro.Status{Code: "ok", Message: "Permitted"}, + EvaluatedAt: time.Now(), + Obligations: []ro.Obligation{ + {ID: "audit", Attributes: map[string]any{"level": "INFO", "message": "created"}}, + {ID: "encryption", Attributes: map[string]any{"algorithm": "AES256"}}, + }, + Advices: []ro.Advice{ + {ID: "analytics", Attributes: map[string]any{"track": true}}, + {ID: "cache-invalidation", Attributes: map[string]any{"keys": []string{"documents"}}}, + }, + }, + obligationHandlers: map[string]error{ + "audit": nil, + "encryption": nil, + }, + adviceHandlers: map[string]error{ + "analytics": nil, + "cache-invalidation": nil, + }, + expectedStatus: 200, + nextCalled: true, + expectedLogs: []expectedLog{ + {level: slog.LevelInfo, contains: "access_permitted"}, + }, + }, + + "should return 403 when decision is not applicable and log info": { + httpMethod: "GET", + httpPath: "/api/unknown", + extractorResult: &ro.AccessRequest{ + Subject: ro.Subject{ID: "user123", Type: "user"}, + Resource: ro.Resource{ID: "unknown", Type: "unknown"}, + Action: ro.Action{ID: "read"}, + }, + orchestratorResult: newBaseNotApplicableResponse(), + expectedStatus: http.StatusForbidden, + expectedErrorResp: &ErrorResponse{ + Error: "access_denied", + Message: "You do not have permission to access this resource", + }, + nextCalled: false, + expectedLogs: []expectedLog{ + {level: slog.LevelInfo, contains: "access_not_applicable"}, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Initialize mocks in test loop + orchestrator := &mockRequestOrchestrator{} + extractor := &mockRequestExtractor{} + obligationHandlers := make(map[string]*mockObligationHandler) + adviceHandlers := make(map[string]*mockAdviceHandler) + + // Setup extractor mock + extractor.On("Extract", mock.Anything, mock.Anything).Return(tc.extractorResult, tc.extractorError) + + // Setup orchestrator mock + if tc.extractorResult != nil { + orchestrator.On("EvaluateAccess", mock.Anything, tc.extractorResult).Return(tc.orchestratorResult, tc.orchestratorError) + } + + // Setup obligation handler mocks + if tc.orchestratorResult != nil { + for obligationID, expectedError := range tc.obligationHandlers { + handler := &mockObligationHandler{} + for _, obligation := range tc.orchestratorResult.Obligations { + if obligation.ID == obligationID { + handler.On("Handle", mock.Anything, obligation, mock.Anything, mock.Anything).Return(expectedError) + break + } + } + obligationHandlers[obligationID] = handler + } + } + + // Setup advice handler mocks + if tc.orchestratorResult != nil { + for adviceID, expectedError := range tc.adviceHandlers { + handler := &mockAdviceHandler{} + for _, advice := range tc.orchestratorResult.Advices { + if advice.ID == adviceID { + handler.On("Handle", mock.Anything, advice, mock.Anything, mock.Anything).Return(expectedError) + break + } + } + adviceHandlers[adviceID] = handler + } + } + + // Create test logger + logHandler := &testLogHandler{} + logger := slog.New(logHandler) + + // Create enforcer options + options := make([]Option, 0) + for obligationID, handler := range obligationHandlers { + options = append(options, WithObligationHandler(obligationID, handler)) + } + for adviceID, handler := range adviceHandlers { + options = append(options, WithAdviceHandler(adviceID, handler)) + } + + // Initialize enforcer in test loop + enforcer := NewEnforcer(orchestrator, extractor, logger, options...) + + // Setup next handler to track if it was called + nextCalled := false + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Create test request and response recorder + req := httptest.NewRequest(tc.httpMethod, tc.httpPath, http.NoBody) + req.RemoteAddr = "192.168.1.1:8080" + recorder := httptest.NewRecorder() + + // Execute + middleware := enforcer.Enforce(nextHandler) + middleware.ServeHTTP(recorder, req) + + // Assert HTTP response status + assert.Equal(t, tc.expectedStatus, recorder.Code) + + // Assert response body for error cases + if tc.expectedErrorResp != nil { + assert.Equal(t, "application/json", recorder.Header().Get("Content-Type")) + var actualErrorResp ErrorResponse + err := json.Unmarshal(recorder.Body.Bytes(), &actualErrorResp) + assert.NoError(t, err, "Response should be valid JSON") + assert.Equal(t, *tc.expectedErrorResp, actualErrorResp) + } + + // Assert next handler was called appropriately + assert.Equal(t, tc.nextCalled, nextCalled) + + // Assert logs: check both message content and level + for _, expected := range tc.expectedLogs { + found := false + for i, msg := range logHandler.messages { + if strings.Contains(msg, expected.contains) && logHandler.levels[i] == expected.level { + found = true + break + } + } + assert.True(t, found, fmt.Sprintf("Expected log with level %v containing '%s' not found. Actual: %v", expected.level, expected.contains, logHandler.messages)) + } + + // Verify all mocks were called as expected + orchestrator.AssertExpectations(t) + extractor.AssertExpectations(t) + for _, handler := range obligationHandlers { + handler.AssertExpectations(t) + } + for _, handler := range adviceHandlers { + handler.AssertExpectations(t) + } + }) + } +} + +func TestEnforcer_WithCustomErrorHandler(t *testing.T) { + // Test that custom error handler is used correctly + orchestrator := &mockRequestOrchestrator{} + extractor := &mockRequestExtractor{} + + // Setup mocks to trigger an error + extractor.On("Extract", mock.Anything, mock.Anything).Return(nil, errors.New("test error")) + + // Custom error handler that adds extra field + customErrorHandler := func(w http.ResponseWriter, _ *http.Request, statusCode int, errorResp ErrorResponse) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Custom-Header", "test-value") + w.WriteHeader(statusCode) + + // Add timestamp to error response + response := map[string]any{ + "error": errorResp.Error, + "message": errorResp.Message, + "timestamp": "2024-01-01T00:00:00Z", // Fixed for testing + } + assert.NoError(t, json.NewEncoder(w).Encode(response)) + } + + logHandler := &testLogHandler{} + logger := slog.New(logHandler) + + // Create enforcer with custom error handler + enforcer := NewEnforcer(orchestrator, extractor, logger, WithErrorHandler(customErrorHandler)) + + // Create test request + req := httptest.NewRequest("GET", "/test", http.NoBody) + recorder := httptest.NewRecorder() + + // Execute + nextHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + t.Error("Next handler should not be called") + }) + middleware := enforcer.Enforce(nextHandler) + middleware.ServeHTTP(recorder, req) + + // Assert custom error handler was used + assert.Equal(t, 400, recorder.Code) + assert.Equal(t, "application/json", recorder.Header().Get("Content-Type")) + assert.Equal(t, "test-value", recorder.Header().Get("X-Custom-Header")) + + // Assert custom response format + var response map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "request_extraction_failed", response["error"]) + assert.Equal(t, "Invalid access request", response["message"]) + assert.Equal(t, "2024-01-01T00:00:00Z", response["timestamp"]) +} diff --git a/examples/abac/internal/enforcer/jwt/subject_extractor.go b/examples/abac/internal/enforcer/jwt/subject_extractor.go new file mode 100644 index 0000000..fd74147 --- /dev/null +++ b/examples/abac/internal/enforcer/jwt/subject_extractor.go @@ -0,0 +1,33 @@ +package jwt + +import ( + "context" + "errors" + "net/http" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/api/rest/middleware" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/enforcer" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" +) + +// subjectExtractor extracts subject information from context (set by JWT middleware) +type subjectExtractor struct{} + +// NewSubjectExtractor creates a new subject extractor that reads from context +func NewSubjectExtractor() enforcer.SubjectExtractor { + return &subjectExtractor{} +} + +// Extract retrieves subject information from request context +func (*subjectExtractor) Extract(_ context.Context, r *http.Request) (*ro.Subject, error) { + userID, ok := middleware.GetUserIDFromContext(r.Context()) + if !ok { + return nil, errors.New("user ID not found in context") + } + + return &ro.Subject{ + ID: userID, + Type: string(infoprovider.InfoTypeUser), + }, nil +} diff --git a/examples/abac/internal/enforcer/jwt/subject_extractor_test.go b/examples/abac/internal/enforcer/jwt/subject_extractor_test.go new file mode 100644 index 0000000..9290a27 --- /dev/null +++ b/examples/abac/internal/enforcer/jwt/subject_extractor_test.go @@ -0,0 +1,163 @@ +package jwt + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/api/rest/middleware" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testContextKey string + +func TestSubjectExtractor_Extract(t *testing.T) { + testCases := map[string]struct { + setupContext func() context.Context + setupRequest func(ctx context.Context) *http.Request + expectedSubject *ro.Subject + expectedError string + }{ + "should extract subject successfully when user ID exists in context": { + setupContext: func() context.Context { + return context.WithValue(context.Background(), middleware.UserIDContextKey, "user123") + }, + setupRequest: func(ctx context.Context) *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedSubject: &ro.Subject{ + ID: "user123", + Type: string(infoprovider.InfoTypeUser), + }, + }, + + "should return error when user ID is not found in context": { + setupContext: func() context.Context { + return context.Background() + }, + setupRequest: func(ctx context.Context) *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedError: "user ID not found in context", + }, + + "should return error when user ID has wrong type in context": { + setupContext: func() context.Context { + return context.WithValue(context.Background(), middleware.UserIDContextKey, 123) + }, + setupRequest: func(ctx context.Context) *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedError: "user ID not found in context", + }, + + "should return error when context has different key": { + setupContext: func() context.Context { + return context.WithValue(context.Background(), testContextKey("different_key"), "user123") + }, + setupRequest: func(ctx context.Context) *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedError: "user ID not found in context", + }, + + "should extract subject with empty user ID if that exists in context": { + setupContext: func() context.Context { + return context.WithValue(context.Background(), middleware.UserIDContextKey, "") + }, + setupRequest: func(ctx context.Context) *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedSubject: &ro.Subject{ + ID: "", + Type: string(infoprovider.InfoTypeUser), + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + extractor := NewSubjectExtractor() + ctx := tc.setupContext() + req := tc.setupRequest(ctx) + + result, err := extractor.Extract(context.Background(), req) + + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tc.expectedSubject.ID, result.ID) + assert.Equal(t, tc.expectedSubject.Type, result.Type) + } + }) + } +} + +func TestSubjectExtractor_ExtractContextParameter(t *testing.T) { + testCases := map[string]struct { + contextParam context.Context + setupRequest func() *http.Request + expectedSubject *ro.Subject + expectedError string + }{ + "should ignore context parameter and use request context instead": { + contextParam: context.WithValue(context.Background(), middleware.UserIDContextKey, "wrong-user"), + setupRequest: func() *http.Request { + ctx := context.WithValue(context.Background(), middleware.UserIDContextKey, "correct-user") + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedSubject: &ro.Subject{ + ID: "correct-user", + Type: string(infoprovider.InfoTypeUser), + }, + }, + + "should use request context even when context parameter is nil": { + contextParam: nil, + setupRequest: func() *http.Request { + ctx := context.WithValue(context.Background(), middleware.UserIDContextKey, "user456") + return httptest.NewRequest("GET", "/test", http.NoBody).WithContext(ctx) + }, + expectedSubject: &ro.Subject{ + ID: "user456", + Type: string(infoprovider.InfoTypeUser), + }, + }, + + "should return error when both contexts have no user ID": { + contextParam: context.Background(), + setupRequest: func() *http.Request { + return httptest.NewRequest("GET", "/test", http.NoBody) + }, + expectedError: "user ID not found in context", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + extractor := NewSubjectExtractor() + req := tc.setupRequest() + + result, err := extractor.Extract(tc.contextParam, req) + + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tc.expectedSubject.ID, result.ID) + assert.Equal(t, tc.expectedSubject.Type, result.Type) + } + }) + } +} diff --git a/examples/abac/internal/enforcer/operations/order_extractor.go b/examples/abac/internal/enforcer/operations/order_extractor.go new file mode 100644 index 0000000..c123689 --- /dev/null +++ b/examples/abac/internal/enforcer/operations/order_extractor.go @@ -0,0 +1,84 @@ +package operations + +import ( + "context" + "fmt" + "net/http" + "regexp" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/enforcer" + ip "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" +) + +const ( + ActionRead = "read" + ActionCreate = "create" +) + +// IDExtractor extracts resource ID from HTTP request +type IDExtractor func(r *http.Request) (string, error) + +// OrderExtractorOption configures order extractor behavior +type OrderExtractorOption func(*orderExtractor) error + +type orderExtractor struct { + action string + idExtractor IDExtractor +} + +// WithIDExtractor configures ID extraction for resource-specific operations +func WithIDExtractor(extractor IDExtractor) OrderExtractorOption { + return func(e *orderExtractor) error { + e.idExtractor = extractor + return nil + } +} + +// NewOrderExtractor creates an order operation extractor +func NewOrderExtractor(action string, options ...OrderExtractorOption) (enforcer.OperationExtractor, error) { + e := &orderExtractor{ + action: action, + } + + for _, option := range options { + if err := option(e); err != nil { + return nil, fmt.Errorf("failed to configure order extractor: %w", err) + } + } + + return e, nil +} + +// Extract extracts operation details from HTTP request +func (e *orderExtractor) Extract(_ context.Context, r *http.Request) (*enforcer.Operation, error) { + operation := &enforcer.Operation{ + Action: ro.Action{ID: e.action}, + Resource: ro.Resource{Type: string(ip.InfoTypeOrder)}, + } + + // Skip ID extraction for operations that don't need it (e.g., create, list) + if e.idExtractor == nil { + return operation, nil + } + + id, err := e.idExtractor(r) + if err != nil { + return nil, fmt.Errorf("failed to extract order ID: %w", err) + } + + operation.Resource.ID = id + return operation, nil +} + +// ExtractOrderIDFromPath extracts order ID from URL path /orders/{id} +func ExtractOrderIDFromPath(r *http.Request) (string, error) { + pattern := regexp.MustCompile(`^/orders/([a-fA-F0-9-]{36})$`) + matches := pattern.FindStringSubmatch(r.URL.Path) + + if len(matches) < 2 { + return "", fmt.Errorf("path %q does not match /orders/{id} pattern", r.URL.Path) + } + + return matches[1], nil +} diff --git a/examples/abac/internal/enforcer/operations/order_extractor_test.go b/examples/abac/internal/enforcer/operations/order_extractor_test.go new file mode 100644 index 0000000..4c43b3e --- /dev/null +++ b/examples/abac/internal/enforcer/operations/order_extractor_test.go @@ -0,0 +1,285 @@ +package operations + +import ( + "context" + "net/http" + "testing" + + ip "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" + "github.com/stretchr/testify/assert" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/enforcer" +) + +func TestNewOrderExtractor(t *testing.T) { + testCases := map[string]struct { + action string + options []OrderExtractorOption + expectedAction string + hasIDExtractor bool + expectedError string + }{ + "should create extractor with action only": { + action: ActionCreate, + options: nil, + expectedAction: ActionCreate, + hasIDExtractor: false, + }, + "should create extractor with action and ID extractor": { + action: ActionRead, + options: []OrderExtractorOption{ + WithIDExtractor(ExtractOrderIDFromPath), + }, + expectedAction: ActionRead, + hasIDExtractor: true, + }, + "should handle multiple options": { + action: ActionRead, + options: []OrderExtractorOption{ + WithIDExtractor(ExtractOrderIDFromPath), + WithIDExtractor(func(*http.Request) (string, error) { + return "test", nil + }), + }, + expectedAction: ActionRead, + hasIDExtractor: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + extractor, err := NewOrderExtractor(tc.action, tc.options...) + + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, extractor) + return + } + + assert.NoError(t, err) + assert.NotNil(t, extractor) + + // Verify internal state + orderExt := extractor.(*orderExtractor) + assert.Equal(t, tc.expectedAction, orderExt.action) + assert.Equal(t, tc.hasIDExtractor, orderExt.idExtractor != nil) + }) + } +} + +func TestOrderExtractor_Extract(t *testing.T) { + customExtractor := func(*http.Request) (string, error) { + return "custom-id", nil + } + + customErrExtractor := func(*http.Request) (string, error) { + return "", assert.AnError + } + + testCases := map[string]struct { + action string + options []OrderExtractorOption + requestPath string + expectedOp *enforcer.Operation + expectedError string + }{ + "should extract operation without ID for create action": { + action: ActionCreate, + requestPath: "/orders", + expectedOp: &enforcer.Operation{ + Action: ro.Action{ID: ActionCreate}, + Resource: ro.Resource{Type: string(ip.InfoTypeOrder)}, + }, + }, + + "should extract operation with UUID ID": { + action: ActionRead, + options: []OrderExtractorOption{WithIDExtractor(ExtractOrderIDFromPath)}, + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd430c8", + expectedOp: &enforcer.Operation{ + Action: ro.Action{ID: ActionRead}, + Resource: ro.Resource{Type: string(ip.InfoTypeOrder), ID: "6ba7b812-9dad-11d1-80b4-00c04fd430c8"}, + }, + }, + + "should handle ID extraction error": { + action: ActionRead, + options: []OrderExtractorOption{WithIDExtractor(ExtractOrderIDFromPath)}, + requestPath: "/orders/invalid-path/extra", + expectedError: "failed to extract order ID", + }, + + "should handle custom ID extractor success": { + action: ActionRead, + options: []OrderExtractorOption{WithIDExtractor(customExtractor)}, + requestPath: "/any/path", + expectedOp: &enforcer.Operation{ + Action: ro.Action{ID: ActionRead}, + Resource: ro.Resource{Type: string(ip.InfoTypeOrder), ID: "custom-id"}, + }, + }, + + "should handle custom ID extractor error": { + action: ActionRead, + options: []OrderExtractorOption{WithIDExtractor(customErrExtractor)}, + requestPath: "/any/path", + expectedError: "failed to extract order ID", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Set up + extractor, err := NewOrderExtractor(tc.action, tc.options...) + assert.NoError(t, err) + + // Construct request from path + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, tc.requestPath, http.NoBody) + assert.NoError(t, err) + + ctx := context.Background() + + // Execute + operation, err := extractor.Extract(ctx, req) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, operation) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.expectedOp, operation) + }) + } +} + +func TestExtractOrderIDFromPath(t *testing.T) { + testCases := map[string]struct { + requestPath string + expectedID string + expectedError string + }{ + "should extract UUID from valid path": { + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd430c8", + expectedID: "6ba7b812-9dad-11d1-80b4-00c04fd430c8", + }, + "should extract UUID with uppercase letters": { + requestPath: "/orders/6BA7B812-9DAD-11D1-80B4-00C04FD430C8", + expectedID: "6BA7B812-9DAD-11D1-80B4-00C04FD430C8", + }, + "should extract UUID with mixed case": { + requestPath: "/orders/6ba7b812-9dad-11D1-80b4-00c04fd430c8", + expectedID: "6ba7b812-9dad-11D1-80b4-00c04fd430c8", + }, + "should fail with numeric ID": { + requestPath: "/orders/123", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with invalid path format": { + requestPath: "/orders", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with extra path segments": { + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd430c8/extra", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with wrong resource path": { + requestPath: "/users/6ba7b812-9dad-11d1-80b4-00c04fd430c8", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with invalid UUID format - too short": { + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd430c", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with invalid UUID format - too long": { + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd430c80", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with invalid UUID format - missing hyphens": { + requestPath: "/orders/6ba7b8129dad11D180b400c04fd430c8", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with invalid UUID format - wrong hyphen positions": { + requestPath: "/orders/550e84-00e29b-41d4a716-446655440000", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with empty ID": { + requestPath: "/orders/", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with special characters in UUID": { + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd430c@", + expectedError: "does not match /orders/{id} pattern", + }, + "should fail with spaces in UUID": { + requestPath: "/orders/6ba7b812-9dad-11d1-80b4-00c04fd43 000", + expectedError: "does not match /orders/{id} pattern", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Construct request from path + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, tc.requestPath, http.NoBody) + assert.NoError(t, err) + + // Execute + id, err := ExtractOrderIDFromPath(req) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Empty(t, id) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.expectedID, id) + }) + } +} + +func TestWithIDExtractor(t *testing.T) { + testCases := map[string]struct { + extractor IDExtractor + expectedError string + }{ + "should configure ID extractor successfully": { + extractor: ExtractOrderIDFromPath, + }, + "should configure custom ID extractor successfully": { + extractor: func(*http.Request) (string, error) { + return "test-id", nil + }, + }, + "should configure nil extractor successfully": { + extractor: nil, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + option := WithIDExtractor(tc.extractor) + assert.NotNil(t, option) + + // Test that option can be applied + extractor := &orderExtractor{} + err := option(extractor) + + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + assert.NoError(t, err) + if tc.extractor != nil { + assert.NotNil(t, extractor.idExtractor) + } else { + assert.Nil(t, extractor.idExtractor) + } + }) + } +} diff --git a/examples/abac/internal/enforcer/request_extractor.go b/examples/abac/internal/enforcer/request_extractor.go new file mode 100644 index 0000000..1e9152a --- /dev/null +++ b/examples/abac/internal/enforcer/request_extractor.go @@ -0,0 +1,167 @@ +package enforcer + +import ( + "context" + "fmt" + "net/http" + "strings" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/pkg/trie" +) + +// SubjectExtractor extracts subject information from HTTP requests +type SubjectExtractor interface { + Extract(ctx context.Context, r *http.Request) (*ro.Subject, error) +} + +// Operation represents an action and resource pair +type Operation struct { + Action ro.Action + Resource ro.Resource +} + +// OperationExtractor extracts operation information from HTTP requests +type OperationExtractor interface { + Extract(ctx context.Context, r *http.Request) (*Operation, error) +} + +// RequestExtractorOption defines configuration options for RequestExtractor +type RequestExtractorOption func(*requestExtractor) error + +type requestExtractor struct { + subjectExtractor SubjectExtractor + operationExtractorTrie *trie.Node[map[string]OperationExtractor] +} + +// normalizeMethod converts HTTP method to uppercase for consistent lookup +func normalizeMethod(method string) string { + return strings.ToUpper(method) +} + +// parsePathSegments splits URL path into segments, handling root path +func parsePathSegments(path string) []string { + pathSegments := strings.Split(strings.Trim(path, "/"), "/") + if len(pathSegments) == 1 && pathSegments[0] == "" { + return []string{} + } + return pathSegments +} + +// WithSubjectExtractor sets the subject extractor +func WithSubjectExtractor(extractor SubjectExtractor) RequestExtractorOption { + return func(re *requestExtractor) error { + if extractor == nil { + return fmt.Errorf("subject extractor cannot be nil") + } + re.subjectExtractor = extractor + return nil + } +} + +// WithOperationExtractor registers an OperationExtractor for specific path and method +func WithOperationExtractor(path, method string, extractor OperationExtractor) RequestExtractorOption { + return func(re *requestExtractor) error { + if path == "" { + return fmt.Errorf("path cannot be empty") + } + if method == "" { + return fmt.Errorf("method cannot be empty") + } + if extractor == nil { + return fmt.Errorf("operation extractor cannot be nil") + } + return re.registerOperationExtractor(path, method, extractor) + } +} + +// NewRequestExtractor creates a new RequestExtractor instance with options +func NewRequestExtractor(options ...RequestExtractorOption) (RequestExtractor, error) { + extractor := &requestExtractor{ + operationExtractorTrie: trie.New[map[string]OperationExtractor](), + } + + // Apply all options + for _, option := range options { + if err := option(extractor); err != nil { + return nil, fmt.Errorf("failed to apply option: %w", err) + } + } + + // Validate required dependencies + if extractor.subjectExtractor == nil { + return nil, fmt.Errorf("subject extractor is required") + } + + return extractor, nil +} + +// registerOperationExtractor registers an OperationExtractor for specific path and method +func (re *requestExtractor) registerOperationExtractor(path, method string, extractor OperationExtractor) error { + pathSegments := parsePathSegments(path) + normalizedMethod := normalizeMethod(method) + + // Search for existing node + node, err := re.operationExtractorTrie.Search(pathSegments) + if err != nil { + // Path doesn't exist, create it + methodMap := make(map[string]OperationExtractor) + methodMap[normalizedMethod] = extractor + return re.operationExtractorTrie.Insert(pathSegments, methodMap) + } + + // Path exists + if _, exists := node.Value[normalizedMethod]; exists { + return fmt.Errorf("method %s already registered for path %s", method, path) + } + + node.Value[normalizedMethod] = extractor + return nil +} + +// Extract extracts AccessRequest from HTTP request +func (re *requestExtractor) Extract(ctx context.Context, r *http.Request) (*ro.AccessRequest, error) { + // Extract subject + subject, err := re.subjectExtractor.Extract(ctx, r) + if err != nil { + return nil, fmt.Errorf("failed to extract subject: %w", err) + } + + // Extract operation + operation, err := re.extractOperation(ctx, r) + if err != nil { + return nil, fmt.Errorf("failed to extract operation: %w", err) + } + + return &ro.AccessRequest{ + Subject: *subject, + Action: operation.Action, + Resource: operation.Resource, + }, nil +} + +// extractOperation extracts operation from HTTP request using registered extractors +func (re *requestExtractor) extractOperation(ctx context.Context, r *http.Request) (*Operation, error) { + pathSegments := parsePathSegments(r.URL.Path) + + // Find matching extractor in trie + node, err := re.operationExtractorTrie.Search(pathSegments) + if err != nil { + return nil, fmt.Errorf("no operation extractor found for path %s: %w", r.URL.Path, err) + } + + // Get method-specific extractor + method := normalizeMethod(r.Method) + extractor, exists := node.Value[method] + if !exists { + return nil, fmt.Errorf("no operation extractor found for method %s on path %s", method, r.URL.Path) + } + + // Extract operation + operation, err := extractor.Extract(ctx, r) + if err != nil { + return nil, fmt.Errorf("operation extraction failed: %w", err) + } + + return operation, nil +} diff --git a/examples/abac/internal/enforcer/request_extractor_test.go b/examples/abac/internal/enforcer/request_extractor_test.go new file mode 100644 index 0000000..3da4ab7 --- /dev/null +++ b/examples/abac/internal/enforcer/request_extractor_test.go @@ -0,0 +1,421 @@ +package enforcer + +import ( + "context" + "errors" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/CameronXie/access-control-explorer/examples/abac/pkg/trie" +) + +// Mock implementations +type mockSubjectExtractor struct { + mock.Mock +} + +func (m *mockSubjectExtractor) Extract(ctx context.Context, r *http.Request) (*ro.Subject, error) { + args := m.Called(ctx, r) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ro.Subject), args.Error(1) +} + +type mockOperationExtractor struct { + mock.Mock +} + +func (m *mockOperationExtractor) Extract(ctx context.Context, r *http.Request) (*Operation, error) { + args := m.Called(ctx, r) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*Operation), args.Error(1) +} + +func TestNewRequestExtractor(t *testing.T) { + testCases := map[string]struct { + options []RequestExtractorOption + expectedError string + }{ + "should fail when no subject extractor provided": { + options: []RequestExtractorOption{}, + expectedError: "subject extractor is required", + }, + + "should succeed with valid subject extractor": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + }, + }, + + "should succeed with subject and operation extractors": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + WithOperationExtractor("/users", "GET", &mockOperationExtractor{}), + }, + }, + + "should fail when empty path provided": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + WithOperationExtractor("", "GET", &mockOperationExtractor{}), + }, + expectedError: "path cannot be empty", + }, + + "should fail when empty method provided": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + WithOperationExtractor("/users", "", &mockOperationExtractor{}), + }, + expectedError: "method cannot be empty", + }, + + "should fail when nil operation extractor provided": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + WithOperationExtractor("/users", "GET", nil), + }, + expectedError: "operation extractor cannot be nil", + }, + + "should fail when duplicate path and method registered": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + WithOperationExtractor("/users", "GET", &mockOperationExtractor{}), + WithOperationExtractor("/users", "GET", &mockOperationExtractor{}), + }, + expectedError: "method GET already registered for path /users", + }, + + "should succeed with multiple different paths": { + options: []RequestExtractorOption{ + WithSubjectExtractor(&mockSubjectExtractor{}), + WithOperationExtractor("/users", "GET", &mockOperationExtractor{}), + WithOperationExtractor("/users", "POST", &mockOperationExtractor{}), + WithOperationExtractor("/admin", "DELETE", &mockOperationExtractor{}), + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + extractor, err := NewRequestExtractor(tc.options...) + + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, extractor) + } else { + assert.NoError(t, err) + assert.NotNil(t, extractor) + } + }) + } +} + +func TestRequestExtractor_Extract(t *testing.T) { + testCases := map[string]struct { + subject *ro.Subject + subjectError error + opExtractorPath string + opExtractorMethod string + operation *Operation + operationError error + shouldExtractOperation bool + request *http.Request + expectedResult *ro.AccessRequest + expectedError string + }{ + "should successfully extract access request": { + subject: &ro.Subject{ + ID: "user123", + Type: "users", + }, + operation: &Operation{ + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{Type: "documents"}, + }, + opExtractorPath: "/documents", + opExtractorMethod: http.MethodGet, + shouldExtractOperation: true, + request: createTestRequest("GET", "/documents"), + expectedResult: &ro.AccessRequest{ + Subject: ro.Subject{ + ID: "user123", + Type: "users", + }, + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{Type: "documents"}, + }, + }, + + "should fail when subject extraction fails": { + subjectError: errors.New("subject extraction failed"), + operation: &Operation{ + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{Type: "documents"}, + }, + opExtractorPath: "/documents", + opExtractorMethod: http.MethodGet, + request: createTestRequest("GET", "/documents"), + expectedError: "failed to extract subject", + }, + + "should fail when operation extraction fails": { + subject: &ro.Subject{ + ID: "user123", + Type: "users", + }, + operationError: errors.New("operation extraction failed"), + opExtractorPath: "/documents", + opExtractorMethod: http.MethodGet, + shouldExtractOperation: true, + request: createTestRequest("GET", "/documents"), + expectedError: "failed to extract operation", + }, + + "should fail when no operation extractor found for path": { + subject: &ro.Subject{ + ID: "user123", + Type: "users", + }, + operation: &Operation{ + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{Type: "documents"}, + }, + opExtractorPath: "/documents", + opExtractorMethod: http.MethodGet, + request: createTestRequest("GET", "/unknown"), + expectedError: "no operation extractor found for path /unknown", + }, + + "should fail when no operation extractor found for method": { + subject: &ro.Subject{ + ID: "user123", + Type: "users", + }, + operation: &Operation{ + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{Type: "documents"}, + }, + opExtractorPath: "/documents", + opExtractorMethod: http.MethodGet, + request: createTestRequest(http.MethodPost, "/documents"), + expectedError: "no operation extractor found for method POST on path /documents", + }, + + "should handle complex paths with parameters": { + subject: &ro.Subject{ + ID: "user123", + Type: "users", + }, + operation: &Operation{ + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{ID: "document123", Type: "documents"}, + }, + opExtractorPath: "/documents/*", + opExtractorMethod: http.MethodGet, + shouldExtractOperation: true, + request: createTestRequest("GET", "/documents/document123"), + expectedResult: &ro.AccessRequest{ + Subject: ro.Subject{ + ID: "user123", + Type: "users", + }, + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{ + ID: "document123", + Type: "documents", + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Setup mocks + subjectExtractor := &mockSubjectExtractor{} + subjectExtractor.On("Extract", mock.Anything, tc.request).Return(tc.subject, tc.subjectError) + + opExtractor := &mockOperationExtractor{} + + if tc.shouldExtractOperation { + opExtractor.On("Extract", mock.Anything, tc.request).Return(tc.operation, tc.operationError) + } + + extractor, err := NewRequestExtractor( + WithSubjectExtractor(subjectExtractor), + WithOperationExtractor(tc.opExtractorPath, tc.opExtractorMethod, opExtractor), + ) + require.NoError(t, err) + + // Execute + result, err := extractor.Extract(context.Background(), tc.request) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + } + + // Verify mocks + subjectExtractor.AssertExpectations(t) + opExtractor.AssertExpectations(t) + }) + } +} + +func TestWithSubjectExtractor(t *testing.T) { + testCases := map[string]struct { + extractor SubjectExtractor + expectedError string + }{ + "should succeed with valid extractor": { + extractor: &mockSubjectExtractor{}, + }, + "should fail with nil extractor": { + extractor: nil, + expectedError: "subject extractor cannot be nil", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + re := &requestExtractor{} + option := WithSubjectExtractor(tc.extractor) + err := option(re) + + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.extractor, re.subjectExtractor) + } + }) + } +} + +func TestWithOperationExtractor(t *testing.T) { + testCases := map[string]struct { + path string + method string + extractor OperationExtractor + expectedError string + }{ + "should succeed with valid parameters": { + path: "/users", + method: "GET", + extractor: &mockOperationExtractor{}, + }, + + "should fail with empty path": { + path: "", + method: "GET", + extractor: &mockOperationExtractor{}, + expectedError: "path cannot be empty", + }, + + "should succeed with root path": { + path: "/", + method: "GET", + extractor: &mockOperationExtractor{}, + }, + + "should fail with empty method": { + path: "/users", + method: "", + extractor: &mockOperationExtractor{}, + expectedError: "method cannot be empty", + }, + + "should fail with nil extractor": { + path: "/users", + method: "GET", + extractor: nil, + expectedError: "operation extractor cannot be nil", + }, + + "should succeed with wildcard path": { + path: "/users/*/profile", + method: "GET", + extractor: &mockOperationExtractor{}, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + re := &requestExtractor{ + operationExtractorTrie: trie.New[map[string]OperationExtractor](), + } + option := WithOperationExtractor(tc.path, tc.method, tc.extractor) + err := option(re) + + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + assert.NoError(t, err) + n, err := re.operationExtractorTrie.Search(parsePathSegments(tc.path)) + assert.NoError(t, err) + assert.Equal(t, tc.extractor, n.Value[tc.method]) + }) + } +} + +func TestNormalizeMethod(t *testing.T) { + testCases := map[string]struct { + method string + expected string + }{ + "should normalize lowercase": { + method: "get", + expected: "GET", + }, + "should normalize mixed case": { + method: "PoSt", + expected: "POST", + }, + "should keep uppercase": { + method: "DELETE", + expected: "DELETE", + }, + "should handle empty string": { + method: "", + expected: "", + }, + "should handle special methods": { + method: "patch", + expected: "PATCH", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := normalizeMethod(tc.method) + assert.Equal(t, tc.expected, result) + }) + } +} + +// Helper function to create HTTP requests +func createTestRequest(method, path string) *http.Request { + req := &http.Request{ + Method: method, + URL: &url.URL{ + Path: path, + }, + } + return req +} diff --git a/examples/abac/internal/infoprovider/infoprovider.go b/examples/abac/internal/infoprovider/infoprovider.go new file mode 100644 index 0000000..088388e --- /dev/null +++ b/examples/abac/internal/infoprovider/infoprovider.go @@ -0,0 +1,44 @@ +package infoprovider + +import ( + "context" + "fmt" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +// InfoType is the key under which a provider is registered. +type InfoType string + +const ( + InfoTypeUser InfoType = "user" + InfoTypeOrder InfoType = "order" + InfoTypeRBAC InfoType = "rbac" +) + +// infoProvider manages a collection of InfoProvider implementations mapped by type. +// It directs requests to the appropriate provider based on the request type. +type infoProvider struct { + providers map[InfoType]ip.InfoProvider +} + +// NewInfoProvider creates and returns an InfoProvider instance that routes requests based on their type using the given map. +func NewInfoProvider(providers map[InfoType]ip.InfoProvider) ip.InfoProvider { + return &infoProvider{ + providers: providers, + } +} + +// GetInfo routes the request to the appropriate provider based on req.Type +func (p *infoProvider) GetInfo(ctx context.Context, req *ip.GetInfoRequest) (*ip.GetInfoResponse, error) { + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + provider, ok := p.providers[InfoType(req.InfoType)] + if !ok { + return nil, fmt.Errorf("unsupported info type %s", req.InfoType) + } + + return provider.GetInfo(ctx, req) +} diff --git a/examples/abac/internal/infoprovider/infoprovider_test.go b/examples/abac/internal/infoprovider/infoprovider_test.go new file mode 100644 index 0000000..2753e3f --- /dev/null +++ b/examples/abac/internal/infoprovider/infoprovider_test.go @@ -0,0 +1,158 @@ +package infoprovider + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +type testContextKey string + +// mockInfoProvider is a mock implementation of InfoProvider for testing +type mockInfoProvider struct { + mock.Mock +} + +func (m *mockInfoProvider) GetInfo(ctx context.Context, req *ip.GetInfoRequest) (*ip.GetInfoResponse, error) { + args := m.Called(ctx, req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + res := args.Get(0).(*ip.GetInfoResponse) + + if ctx.Value(testContextKey("test")) != nil { + res.Info["test"] = ctx.Value(testContextKey("test")) + } + + return res, args.Error(1) +} + +func TestInfoProvider_GetInfo(t *testing.T) { + testCases := map[string]struct { + req *ip.GetInfoRequest + mockInfoProviderResp *ip.GetInfoResponse + mockInfoProviderErr error + setupContext func() context.Context + expectedResult *ip.GetInfoResponse + expectedError string + }{ + "should return error when request is nil": { + req: nil, + setupContext: func() context.Context { return context.Background() }, + expectedError: "request cannot be nil", + expectedResult: nil, + }, + + "should return info when info provider exists": { + req: &ip.GetInfoRequest{ + InfoType: "user", + Params: "user123", + }, + mockInfoProviderResp: &ip.GetInfoResponse{ + Info: map[string]any{ + "id": "user123", + "name": "John Doe", + "department": "engineering", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: &ip.GetInfoResponse{ + Info: map[string]any{ + "id": "user123", + "name": "John Doe", + "department": "engineering", + }, + }, + }, + + "should return error when unsupported info type is requested": { + req: &ip.GetInfoRequest{ + InfoType: "unsupported", + Params: "param123", + }, + setupContext: func() context.Context { return context.Background() }, + expectedError: "unsupported info type unsupported", + expectedResult: nil, + }, + + "should return error when provider returns error": { + req: &ip.GetInfoRequest{ + InfoType: "user", + Params: "invalidUser", + }, + mockInfoProviderErr: fmt.Errorf("user invalidUser not found"), + setupContext: func() context.Context { return context.Background() }, + expectedError: "user invalidUser not found", + expectedResult: nil, + }, + + "should pass context to underlying provider": { + req: &ip.GetInfoRequest{ + InfoType: "user", + Params: "contextTest", + }, + mockInfoProviderResp: &ip.GetInfoResponse{ + Info: map[string]any{ + "id": "user123", + "name": "John Doe", + "department": "engineering", + }, + }, + setupContext: func() context.Context { + return context.WithValue(context.Background(), testContextKey("test"), "value") + }, + expectedResult: &ip.GetInfoResponse{ + Info: map[string]any{ + "id": "user123", + "name": "John Doe", + "department": "engineering", + "test": "value", + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Create mock providers + userProvider := new(mockInfoProvider) + + // Setup context + ctx := tc.setupContext() + + // Setup mocks + if tc.req != nil && tc.req.InfoType == "user" { + userProvider.On("GetInfo", ctx, tc.req).Return( + tc.mockInfoProviderResp, + tc.mockInfoProviderErr, + ) + } + + // Create an info provider + p := NewInfoProvider(map[InfoType]ip.InfoProvider{ + "user": userProvider, + }) + + // Execute + result, err := p.GetInfo(ctx, tc.req) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Equal(t, tc.expectedResult, result) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + } + + // Verify mock expectations + userProvider.AssertExpectations(t) + }) + } +} diff --git a/examples/abac/internal/infoprovider/order_provider.go b/examples/abac/internal/infoprovider/order_provider.go new file mode 100644 index 0000000..8d64d0b --- /dev/null +++ b/examples/abac/internal/infoprovider/order_provider.go @@ -0,0 +1,55 @@ +package infoprovider + +import ( + "context" + "fmt" + + "github.com/google/uuid" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +// OrderAttributesRepository defines the contract for order attribute operations +type OrderAttributesRepository interface { + GetOrderAttributesByID(ctx context.Context, id uuid.UUID) (map[string]any, error) +} + +// orderProvider implements InfoProvider for order data +type orderProvider struct { + orderRepo OrderAttributesRepository +} + +// NewOrderProvider creates a new order info provider with dependency injection +func NewOrderProvider(orderRepo OrderAttributesRepository) ip.InfoProvider { + return &orderProvider{ + orderRepo: orderRepo, + } +} + +// GetInfo retrieves order attributes based on the provided request containing an order ID. +func (p *orderProvider) GetInfo(ctx context.Context, req *ip.GetInfoRequest) (*ip.GetInfoResponse, error) { + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + orderIDStr, ok := req.Params.(string) + if !ok { + return nil, fmt.Errorf("order ID parameter must be a string, got %T: %v", req.Params, req.Params) + } + + if orderIDStr == "" { + return &ip.GetInfoResponse{Info: map[string]any{}}, nil + } + + orderID, err := uuid.Parse(orderIDStr) + if err != nil { + return nil, fmt.Errorf("order ID must be a valid UUID format, got: %s", orderIDStr) + } + + attrs, err := p.orderRepo.GetOrderAttributesByID(ctx, orderID) + if err != nil { + return nil, err + } + + return &ip.GetInfoResponse{Info: attrs}, nil +} diff --git a/examples/abac/internal/infoprovider/order_provider_test.go b/examples/abac/internal/infoprovider/order_provider_test.go new file mode 100644 index 0000000..2bf992f --- /dev/null +++ b/examples/abac/internal/infoprovider/order_provider_test.go @@ -0,0 +1,216 @@ +//nolint:dupl // Similar structure to user_provider by design: separate domain providers share flow now. +package infoprovider + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +// mockOrderAttributesRepository is a mock implementation of OrderAttributesRepository +type mockOrderAttributesRepository struct { + mock.Mock +} + +func (m *mockOrderAttributesRepository) GetOrderAttributesByID(ctx context.Context, id uuid.UUID) (map[string]any, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(map[string]any), args.Error(1) +} + +func TestNewOrderProvider(t *testing.T) { + mockRepo := new(mockOrderAttributesRepository) + provider := NewOrderProvider(mockRepo) + + assert.NotNil(t, provider) + assert.IsType(t, &orderProvider{}, provider) +} + +func TestOrderProvider_GetInfo(t *testing.T) { + orderID := uuid.New() + anotherOrderID := uuid.New() + + testCases := map[string]struct { + request *ip.GetInfoRequest + mockOrderResp map[string]any + mockOrderErr error + expectedResponse *ip.GetInfoResponse + expectedError string + shouldCallMock bool + expectedOrderID uuid.UUID + }{ + "should return order attributes when valid order ID is provided": { + request: &ip.GetInfoRequest{ + Params: orderID.String(), + }, + mockOrderResp: map[string]any{ + "category": "premium", + "price": 1999.99, + "currency": "USD", + "user_id": "user_123", + "total_items": 5, + "status": "processing", + "shipping": map[string]any{ + "address": "123 Main St", + "city": "New York", + "urgent": true, + }, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "category": "premium", + "price": 1999.99, + "currency": "USD", + "user_id": "user_123", + "total_items": 5, + "status": "processing", + "shipping": map[string]any{ + "address": "123 Main St", + "city": "New York", + "urgent": true, + }, + }, + }, + shouldCallMock: true, + expectedOrderID: orderID, + }, + + "should return empty attributes when order has no attributes": { + request: &ip.GetInfoRequest{ + Params: orderID.String(), + }, + mockOrderResp: map[string]any{}, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{}, + }, + shouldCallMock: true, + expectedOrderID: orderID, + }, + + "should return error when request is nil": { + request: nil, + expectedError: "request cannot be nil", + shouldCallMock: false, + }, + + "should return error when params is not string": { + request: &ip.GetInfoRequest{ + Params: 12345, + }, + expectedError: "order ID parameter must be a string, got int: 12345", + shouldCallMock: false, + }, + + "should return error when params is not valid UUID": { + request: &ip.GetInfoRequest{ + Params: "invalid-uuid", + }, + expectedError: "order ID must be a valid UUID format, got: invalid-uuid", + shouldCallMock: false, + }, + + "should return error when repository returns NotFoundError": { + request: &ip.GetInfoRequest{ + Params: anotherOrderID.String(), + }, + mockOrderErr: &repository.NotFoundError{ + Resource: "order", + Key: "id", + Value: anotherOrderID.String(), + }, + expectedError: fmt.Sprintf("order with id %s not found", anotherOrderID.String()), + shouldCallMock: true, + expectedOrderID: anotherOrderID, + }, + + "should return error when repository returns database error": { + request: &ip.GetInfoRequest{ + Params: orderID.String(), + }, + mockOrderErr: errors.New("database connection failed"), + expectedError: "database connection failed", + shouldCallMock: true, + expectedOrderID: orderID, + }, + + "should handle different UUID formats": { + request: &ip.GetInfoRequest{ + Params: orderID.String(), + }, + mockOrderResp: map[string]any{ + "category": "standard", + "price": 99.99, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "category": "standard", + "price": 99.99, + }, + }, + shouldCallMock: true, + expectedOrderID: orderID, + }, + + "should handle params as interface{} containing string": { + request: &ip.GetInfoRequest{ + Params: any(orderID.String()), + }, + mockOrderResp: map[string]any{ + "category": "test", + "price": 49.99, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "category": "test", + "price": 49.99, + }, + }, + shouldCallMock: true, + expectedOrderID: orderID, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Create mock repository + mockRepo := new(mockOrderAttributesRepository) + + // Setup mocks in test case loop + if tc.shouldCallMock { + mockRepo.On("GetOrderAttributesByID", mock.Anything, tc.expectedOrderID).Return( + tc.mockOrderResp, + tc.mockOrderErr, + ) + } + + provider := NewOrderProvider(mockRepo) + + // Execute + response, err := provider.GetInfo(context.Background(), tc.request) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.EqualValues(t, tc.expectedResponse, response) + } + + // Verify mock expectations + mockRepo.AssertExpectations(t) + }) + } +} diff --git a/examples/abac/internal/infoprovider/rbac_provider.go b/examples/abac/internal/infoprovider/rbac_provider.go new file mode 100644 index 0000000..f5bc012 --- /dev/null +++ b/examples/abac/internal/infoprovider/rbac_provider.go @@ -0,0 +1,107 @@ +package infoprovider + +import ( + "context" + "fmt" + "strings" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +// Permission is a role permission with optional conditions. +type Permission struct { + ActionName string `json:"action"` + ResourceName string `json:"resource"` + Conditions []PermissionCondition `json:"conditions,omitempty"` +} + +// PermissionCondition is a conditional constraint on a permission. +type PermissionCondition struct { + AttributeKey string `json:"attribute_key"` + Operator string `json:"operator"` + AttributeValue any `json:"attribute_value"` +} + +// RoleHierarchy contains requested roles and their descendants. +type RoleHierarchy struct { + RequestedRoles []string `json:"requested_roles"` + Descendants []string `json:"descendants"` +} + +// RBACRepository is the read-only contract this provider needs. +type RBACRepository interface { + GetRoleDescendants(ctx context.Context, rootRoles []string) ([]string, error) + GetPermissionsByRoles(ctx context.Context, roles []string) (map[string][]Permission, error) +} + +type roleBasedAccessProvider struct { + repo RBACRepository +} + +// NewRoleBasedAccessProvider creates a storage-agnostic RBAC info provider. +func NewRoleBasedAccessProvider(repo RBACRepository) ip.InfoProvider { + return &roleBasedAccessProvider{repo: repo} +} + +// GetInfo retrieves role hierarchy and permissions for the provided roles. +func (p *roleBasedAccessProvider) GetInfo(ctx context.Context, req *ip.GetInfoRequest) (*ip.GetInfoResponse, error) { + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + // Expecting []string in Params (can come as []any from JSON decode paths, normalize here). + switch v := req.Params.(type) { + case []string: + return p.handle(ctx, v) + case []any: + roleNames := make([]string, 0, len(v)) + for _, r := range v { + s, ok := r.(string) + if !ok { + return nil, fmt.Errorf("role names must be []string") + } + roleNames = append(roleNames, s) + } + return p.handle(ctx, roleNames) + default: + return nil, fmt.Errorf("role names parameter must be a []string, got %T", req.Params) + } +} + +func (p *roleBasedAccessProvider) handle(ctx context.Context, roleNames []string) (*ip.GetInfoResponse, error) { + // Normalize and ensure non-empty + normalized := make([]string, 0, len(roleNames)) + for _, r := range roleNames { + if s := strings.TrimSpace(r); s != "" { + normalized = append(normalized, s) + } + } + if len(normalized) == 0 { + return nil, fmt.Errorf("at least one role name must be provided") + } + + // Get all descendant roles (including the roots) + descendants, err := p.repo.GetRoleDescendants(ctx, normalized) + if err != nil { + return nil, fmt.Errorf("failed to get role descendants for roles %v: %w", normalized, err) + } + if len(descendants) == 0 { + return nil, fmt.Errorf("none of the requested roles were found: %v", normalized) + } + + // Get permissions for all roles in the hierarchy + perms, err := p.repo.GetPermissionsByRoles(ctx, descendants) + if err != nil { + return nil, fmt.Errorf("failed to get permissions for roles %v: %w", descendants, err) + } + + return &ip.GetInfoResponse{ + Info: map[string]any{ + "role_hierarchy": RoleHierarchy{ + RequestedRoles: normalized, + Descendants: descendants, + }, + "role_permissions": perms, + }, + }, nil +} diff --git a/examples/abac/internal/infoprovider/rbac_provider_test.go b/examples/abac/internal/infoprovider/rbac_provider_test.go new file mode 100644 index 0000000..94f6654 --- /dev/null +++ b/examples/abac/internal/infoprovider/rbac_provider_test.go @@ -0,0 +1,326 @@ +package infoprovider + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +// mockRBACRepository is a mock implementation of RBACRepository +type mockRBACRepository struct { + mock.Mock +} + +func (m *mockRBACRepository) GetRoleDescendants(ctx context.Context, rootRoles []string) ([]string, error) { + args := m.Called(ctx, rootRoles) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) +} + +func (m *mockRBACRepository) GetPermissionsByRoles(ctx context.Context, roles []string) (map[string][]Permission, error) { + args := m.Called(ctx, roles) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(map[string][]Permission), args.Error(1) +} + +func TestNewRoleBasedAccessProvider(t *testing.T) { + mockRepo := new(mockRBACRepository) + provider := NewRoleBasedAccessProvider(mockRepo) + + assert.NotNil(t, provider) + assert.IsType(t, &roleBasedAccessProvider{}, provider) +} + +func TestRoleBasedAccessProvider_GetInfo(t *testing.T) { + testCases := map[string]struct { + request *ip.GetInfoRequest + mockDescendantsResp []string + mockDescendantsErr error + mockPermissionsResp map[string][]Permission + mockPermissionsErr error + expectedResponse *ip.GetInfoResponse + expectedError string + shouldCallDescendants bool + shouldCallPermissions bool + expectedDescendantsParam []string + expectedPermissionsParam []string + }{ + "should handle complex permissions with multiple conditions": { + request: &ip.GetInfoRequest{ + Params: []string{"security_admin"}, + }, + mockDescendantsResp: []string{"security_admin", "admin", "user"}, + mockPermissionsResp: map[string][]Permission{ + "security_admin": { + { + ActionName: "access", + ResourceName: "security_logs", + Conditions: []PermissionCondition{ + { + AttributeKey: "clearance_level", + Operator: "gte", + AttributeValue: "level5", + }, + { + AttributeKey: "department", + Operator: "in", + AttributeValue: []string{"security", "compliance"}, + }, + }, + }, + { + ActionName: "modify", + ResourceName: "user_permissions", + Conditions: []PermissionCondition{ + { + AttributeKey: "target_user_level", + Operator: "lt", + AttributeValue: "admin", + }, + }, + }, + }, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "role_hierarchy": RoleHierarchy{ + RequestedRoles: []string{"security_admin"}, + Descendants: []string{"security_admin", "admin", "user"}, + }, + "role_permissions": map[string][]Permission{ + "security_admin": { + { + ActionName: "access", + ResourceName: "security_logs", + Conditions: []PermissionCondition{ + { + AttributeKey: "clearance_level", + Operator: "gte", + AttributeValue: "level5", + }, + { + AttributeKey: "department", + Operator: "in", + AttributeValue: []string{"security", "compliance"}, + }, + }, + }, + { + ActionName: "modify", + ResourceName: "user_permissions", + Conditions: []PermissionCondition{ + { + AttributeKey: "target_user_level", + Operator: "lt", + AttributeValue: "admin", + }, + }, + }, + }, + }, + }, + }, + shouldCallDescendants: true, + shouldCallPermissions: true, + expectedDescendantsParam: []string{"security_admin"}, + expectedPermissionsParam: []string{"security_admin", "admin", "user"}, + }, + + "should handle []any params and convert to []string": { + request: &ip.GetInfoRequest{ + Params: []any{"admin", "manager"}, + }, + mockDescendantsResp: []string{"admin", "manager", "user"}, + mockPermissionsResp: map[string][]Permission{ + "admin": { + { + ActionName: "delete", + ResourceName: "order", + Conditions: []PermissionCondition{}, + }, + }, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "role_hierarchy": RoleHierarchy{ + RequestedRoles: []string{"admin", "manager"}, + Descendants: []string{"admin", "manager", "user"}, + }, + "role_permissions": map[string][]Permission{ + "admin": { + { + ActionName: "delete", + ResourceName: "order", + Conditions: []PermissionCondition{}, + }, + }, + }, + }, + }, + shouldCallDescendants: true, + shouldCallPermissions: true, + expectedDescendantsParam: []string{"admin", "manager"}, + expectedPermissionsParam: []string{"admin", "manager", "user"}, + }, + + "should handle roles with whitespace by trimming": { + request: &ip.GetInfoRequest{ + Params: []string{" admin ", " manager ", " user"}, + }, + mockDescendantsResp: []string{"admin", "manager", "user"}, + mockPermissionsResp: map[string][]Permission{}, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "role_hierarchy": RoleHierarchy{ + RequestedRoles: []string{"admin", "manager", "user"}, + Descendants: []string{"admin", "manager", "user"}, + }, + "role_permissions": map[string][]Permission{}, + }, + }, + shouldCallDescendants: true, + shouldCallPermissions: true, + expectedDescendantsParam: []string{"admin", "manager", "user"}, + expectedPermissionsParam: []string{"admin", "manager", "user"}, + }, + + "should return error when request is nil": { + request: nil, + expectedError: "request cannot be nil", + shouldCallDescendants: false, + shouldCallPermissions: false, + }, + + "should return error when params is not []string or []any": { + request: &ip.GetInfoRequest{ + Params: "invalid-params", + }, + expectedError: "role names parameter must be a []string, got string", + shouldCallDescendants: false, + shouldCallPermissions: false, + }, + + "should return error when []any contains non-string elements": { + request: &ip.GetInfoRequest{ + Params: []any{"admin", 123, "user"}, + }, + expectedError: "role names must be []string", + shouldCallDescendants: false, + shouldCallPermissions: false, + }, + + "should return error when all roles are empty after trimming": { + request: &ip.GetInfoRequest{ + Params: []string{"", " ", "\t\n"}, + }, + expectedError: "at least one role name must be provided", + shouldCallDescendants: false, + shouldCallPermissions: false, + }, + + "should return error when no role descendants found": { + request: &ip.GetInfoRequest{ + Params: []string{"nonexistent_role"}, + }, + mockDescendantsResp: []string{}, + expectedError: "none of the requested roles were found: [nonexistent_role]", + shouldCallDescendants: true, + shouldCallPermissions: false, + expectedDescendantsParam: []string{"nonexistent_role"}, + }, + + "should return error when GetRoleDescendants fails": { + request: &ip.GetInfoRequest{ + Params: []string{"admin"}, + }, + mockDescendantsErr: errors.New("database connection failed"), + expectedError: "failed to get role descendants for roles [admin]: database connection failed", + shouldCallDescendants: true, + shouldCallPermissions: false, + expectedDescendantsParam: []string{"admin"}, + }, + + "should return error when GetPermissionsByRoles fails": { + request: &ip.GetInfoRequest{ + Params: []string{"admin"}, + }, + mockDescendantsResp: []string{"admin", "user"}, + mockPermissionsErr: errors.New("permission query failed"), + expectedError: "failed to get permissions for roles [admin user]: permission query failed", + shouldCallDescendants: true, + shouldCallPermissions: true, + expectedDescendantsParam: []string{"admin"}, + expectedPermissionsParam: []string{"admin", "user"}, + }, + + "should handle empty permissions map": { + request: &ip.GetInfoRequest{ + Params: []string{"readonly_user"}, + }, + mockDescendantsResp: []string{"readonly_user"}, + mockPermissionsResp: map[string][]Permission{}, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "role_hierarchy": RoleHierarchy{ + RequestedRoles: []string{"readonly_user"}, + Descendants: []string{"readonly_user"}, + }, + "role_permissions": map[string][]Permission{}, + }, + }, + shouldCallDescendants: true, + shouldCallPermissions: true, + expectedDescendantsParam: []string{"readonly_user"}, + expectedPermissionsParam: []string{"readonly_user"}, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Create mock repository + mockRepo := new(mockRBACRepository) + + // Setup mocks in test case loop + if tc.shouldCallDescendants { + mockRepo.On("GetRoleDescendants", mock.Anything, tc.expectedDescendantsParam).Return( + tc.mockDescendantsResp, + tc.mockDescendantsErr, + ) + } + if tc.shouldCallPermissions { + mockRepo.On("GetPermissionsByRoles", mock.Anything, tc.expectedPermissionsParam).Return( + tc.mockPermissionsResp, + tc.mockPermissionsErr, + ) + } + + provider := NewRoleBasedAccessProvider(mockRepo) + + // Execute + response, err := provider.GetInfo(context.Background(), tc.request) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.EqualValues(t, tc.expectedResponse, response) + } + + // Verify mock expectations + mockRepo.AssertExpectations(t) + }) + } +} diff --git a/examples/abac/internal/infoprovider/user_provider.go b/examples/abac/internal/infoprovider/user_provider.go new file mode 100644 index 0000000..743b29e --- /dev/null +++ b/examples/abac/internal/infoprovider/user_provider.go @@ -0,0 +1,51 @@ +package infoprovider + +import ( + "context" + "fmt" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" + "github.com/google/uuid" +) + +// UserAttributesRepository defines the contract for user attribute operations +type UserAttributesRepository interface { + GetUserAttributesByID(ctx context.Context, id uuid.UUID) (map[string]any, error) +} + +// userProvider implements InfoProvider for user data +type userProvider struct { + userRepo UserAttributesRepository +} + +// NewUserProvider creates a new user info provider with dependency injection +func NewUserProvider(userRepo UserAttributesRepository) ip.InfoProvider { + return &userProvider{ + userRepo: userRepo, + } +} + +// GetInfo retrieves user attributes based on the provided request containing a user ID. +// Returns attributes with roles guaranteed to be []string type. +func (p *userProvider) GetInfo(ctx context.Context, req *ip.GetInfoRequest) (*ip.GetInfoResponse, error) { + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + userIDStr, ok := req.Params.(string) + if !ok { + return nil, fmt.Errorf("user ID parameter must be a string, got %T: %v", req.Params, req.Params) + } + + userID, err := uuid.Parse(userIDStr) + if err != nil { + return nil, fmt.Errorf("user ID must be a valid UUID format, got: %s", userIDStr) + } + + attrs, err := p.userRepo.GetUserAttributesByID(ctx, userID) + if err != nil { + return nil, err + } + + return &ip.GetInfoResponse{Info: attrs}, nil +} diff --git a/examples/abac/internal/infoprovider/user_provider_test.go b/examples/abac/internal/infoprovider/user_provider_test.go new file mode 100644 index 0000000..8a5938f --- /dev/null +++ b/examples/abac/internal/infoprovider/user_provider_test.go @@ -0,0 +1,294 @@ +//nolint:dupl // Similar structure to order_provider by design: separate domain providers share flow now. +package infoprovider + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + ip "github.com/CameronXie/access-control-explorer/abac/infoprovider" +) + +// mockUserAttributesRepository is a mock implementation of UserAttributesRepository +type mockUserAttributesRepository struct { + mock.Mock +} + +func (m *mockUserAttributesRepository) GetUserAttributesByID(ctx context.Context, id uuid.UUID) (map[string]any, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(map[string]any), args.Error(1) +} + +func TestNewUserProvider(t *testing.T) { + mockRepo := new(mockUserAttributesRepository) + provider := NewUserProvider(mockRepo) + + assert.NotNil(t, provider) + assert.IsType(t, &userProvider{}, provider) +} + +func TestUserProvider_GetInfo(t *testing.T) { + userID := uuid.New() + anotherUserID := uuid.New() + + testCases := map[string]struct { + request *ip.GetInfoRequest + mockUserResp map[string]any + mockUserErr error + expectedResponse *ip.GetInfoResponse + expectedError string + shouldCallMock bool + expectedUserID uuid.UUID + }{ + "should return user attributes with roles as string array": { + request: &ip.GetInfoRequest{ + Params: userID.String(), + }, + mockUserResp: map[string]any{ + "roles": []string{"manager", "reviewer"}, + "department": "sales", + "region": "europe", + "level": "manager", + "team_members": 15, + "budget_limit": 50000.75, + "preferences": map[string]any{ + "theme": "dark", + "language": "en", + "timezone": "UTC+1", + "notifications": map[string]any{ + "email": true, + "sms": false, + "push": true, + }, + }, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "roles": []string{"manager", "reviewer"}, + "department": "sales", + "region": "europe", + "level": "manager", + "team_members": 15, + "budget_limit": 50000.75, + "preferences": map[string]any{ + "theme": "dark", + "language": "en", + "timezone": "UTC+1", + "notifications": map[string]any{ + "email": true, + "sms": false, + "push": true, + }, + }, + }, + }, + shouldCallMock: true, + expectedUserID: userID, + }, + + "should return user attributes with empty roles array": { + request: &ip.GetInfoRequest{ + Params: userID.String(), + }, + mockUserResp: map[string]any{ + "roles": []string{}, + "department": "hr", + "level": "junior", + "active": true, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "roles": []string{}, + "department": "hr", + "level": "junior", + "active": true, + }, + }, + shouldCallMock: true, + expectedUserID: userID, + }, + + "should return empty attributes when user has no attributes": { + request: &ip.GetInfoRequest{ + Params: userID.String(), + }, + mockUserResp: map[string]any{ + "roles": []string{}, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "roles": []string{}, + }, + }, + shouldCallMock: true, + expectedUserID: userID, + }, + + "should return error when request is nil": { + request: nil, + expectedError: "request cannot be nil", + shouldCallMock: false, + }, + + "should return error when params is not string": { + request: &ip.GetInfoRequest{ + Params: 12345, + }, + expectedError: "user ID parameter must be a string, got int: 12345", + shouldCallMock: false, + }, + + "should return error when params is not valid UUID": { + request: &ip.GetInfoRequest{ + Params: "invalid-uuid", + }, + expectedError: "user ID must be a valid UUID format, got: invalid-uuid", + shouldCallMock: false, + }, + + "should return error when repository returns NotFoundError": { + request: &ip.GetInfoRequest{ + Params: anotherUserID.String(), + }, + mockUserErr: &repository.NotFoundError{ + Resource: "user", + Key: "id", + Value: anotherUserID.String(), + }, + expectedError: fmt.Sprintf("user with id %s not found", anotherUserID.String()), + shouldCallMock: true, + expectedUserID: anotherUserID, + }, + + "should return error when repository returns database error": { + request: &ip.GetInfoRequest{ + Params: userID.String(), + }, + mockUserErr: errors.New("database connection failed"), + expectedError: "database connection failed", + shouldCallMock: true, + expectedUserID: userID, + }, + + "should handle different UUID formats": { + request: &ip.GetInfoRequest{ + Params: userID.String(), + }, + mockUserResp: map[string]any{ + "roles": []string{"user"}, + "department": "support", + "active": true, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "roles": []string{"user"}, + "department": "support", + "active": true, + }, + }, + shouldCallMock: true, + expectedUserID: userID, + }, + + "should handle params as interface{} containing string": { + request: &ip.GetInfoRequest{ + Params: any(userID.String()), + }, + mockUserResp: map[string]any{ + "roles": []string{"test"}, + "department": "testing", + "temporary": true, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "roles": []string{"test"}, + "department": "testing", + "temporary": true, + }, + }, + shouldCallMock: true, + expectedUserID: userID, + }, + + "should handle user with mixed data types": { + request: &ip.GetInfoRequest{ + Params: userID.String(), + }, + mockUserResp: map[string]any{ + "roles": []string{"developer", "reviewer"}, + "employee_id": 12345, + "salary": 75000.50, + "is_remote": true, + "start_date": "2023-01-15", + "skills": []string{"Go", "JavaScript", "Python", "Docker"}, + "certifications": []any{"AWS Solutions Architect", 2023, true}, + "project_stats": map[string]any{ + "projects_completed": 15, + "avg_rating": 4.8, + "languages_used": []string{"Go", "TypeScript", "SQL"}, + }, + }, + expectedResponse: &ip.GetInfoResponse{ + Info: map[string]any{ + "roles": []string{"developer", "reviewer"}, + "employee_id": 12345, + "salary": 75000.50, + "is_remote": true, + "start_date": "2023-01-15", + "skills": []string{"Go", "JavaScript", "Python", "Docker"}, + "certifications": []any{"AWS Solutions Architect", 2023, true}, + "project_stats": map[string]any{ + "projects_completed": 15, + "avg_rating": 4.8, + "languages_used": []string{"Go", "TypeScript", "SQL"}, + }, + }, + }, + shouldCallMock: true, + expectedUserID: userID, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Create mock repository + mockRepo := new(mockUserAttributesRepository) + + // Setup mocks in test case loop + if tc.shouldCallMock { + mockRepo.On("GetUserAttributesByID", mock.Anything, tc.expectedUserID).Return( + tc.mockUserResp, + tc.mockUserErr, + ) + } + + provider := NewUserProvider(mockRepo) + + // Execute + response, err := provider.GetInfo(context.Background(), tc.request) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.EqualValues(t, tc.expectedResponse, response) + } + + // Verify mock expectations + mockRepo.AssertExpectations(t) + }) + } +} diff --git a/examples/abac/internal/obligation/auditlog_handler.go b/examples/abac/internal/obligation/auditlog_handler.go new file mode 100644 index 0000000..efcb514 --- /dev/null +++ b/examples/abac/internal/obligation/auditlog_handler.go @@ -0,0 +1,76 @@ +package obligation + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "strings" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" +) + +// AuditLogHandler logs messages based on PDP obligations +type AuditLogHandler struct { + logger *slog.Logger +} + +// AuditLogAttributes defines the fixed structure of audit log obligations +type AuditLogAttributes struct { + Level string `json:"level"` // Log level (DEBUG, INFO, WARN, ERROR) + Message string `json:"message"` // Message to be logged +} + +// NewAuditLogHandler creates a new audit log handler +func NewAuditLogHandler(logger *slog.Logger) *AuditLogHandler { + return &AuditLogHandler{logger: logger} +} + +// Handle processes audit logging obligations from PDP +func (h *AuditLogHandler) Handle(ctx context.Context, obligation ro.Obligation, _ http.ResponseWriter, _ *http.Request) error { + var attrs AuditLogAttributes + if err := parseAttributes(obligation.Attributes, &attrs); err != nil { + return fmt.Errorf("invalid audit log attributes: %w", err) + } + + h.logger.LogAttrs(ctx, parseLogLevel(attrs.Level), attrs.Message, + slog.String("obligation_id", obligation.ID)) + + return nil +} + +// parseAttributes converts and validates obligation attributes +func parseAttributes(attrs map[string]any, result *AuditLogAttributes) error { + data, err := json.Marshal(attrs) + if err != nil { + return fmt.Errorf("failed to marshal attributes: %w", err) + } + + if err := json.Unmarshal(data, result); err != nil { + return fmt.Errorf("failed to unmarshal attributes: %w", err) + } + + if result.Level == "" { + return fmt.Errorf("level is required") + } + if result.Message == "" { + return fmt.Errorf("message is required") + } + + return nil +} + +// parseLogLevel converts string level to slog.Level +func parseLogLevel(level string) slog.Level { + switch strings.ToUpper(level) { + case "DEBUG": + return slog.LevelDebug + case "WARN": + return slog.LevelWarn + case "ERROR": + return slog.LevelError + default: + return slog.LevelInfo + } +} diff --git a/examples/abac/internal/obligation/auditlog_handler_test.go b/examples/abac/internal/obligation/auditlog_handler_test.go new file mode 100644 index 0000000..335df56 --- /dev/null +++ b/examples/abac/internal/obligation/auditlog_handler_test.go @@ -0,0 +1,116 @@ +package obligation + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/stretchr/testify/assert" +) + +type testLogHandler struct { + messages []string + levels []slog.Level +} + +func (h *testLogHandler) Handle(_ context.Context, r slog.Record) error { //nolint:gocritic // slog.Handler interface + h.messages = append(h.messages, r.Message) + h.levels = append(h.levels, r.Level) + return nil +} + +func (*testLogHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } +func (h *testLogHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *testLogHandler) WithGroup(_ string) slog.Handler { return h } +func (h *testLogHandler) reset() { h.messages = nil; h.levels = nil } + +func TestAuditLogHandler_Handle(t *testing.T) { + testCases := map[string]struct { + obligation ro.Obligation + expectedError string + expectedMessage string + expectedLogLevel slog.Level + }{ + "should log message with ERROR level": { + obligation: ro.Obligation{ + ID: "audit_log", + Attributes: map[string]any{ + "level": "ERROR", + "message": "access denied", + }, + }, + expectedMessage: "access denied", + expectedLogLevel: slog.LevelError, + }, + "should log message with INFO level when level is invalid": { + obligation: ro.Obligation{ + ID: "audit_log", + Attributes: map[string]any{ + "level": "INVALID", + "message": "test message", + }, + }, + expectedMessage: "test message", + expectedLogLevel: slog.LevelInfo, + }, + "should return error when level is missing": { + obligation: ro.Obligation{ + ID: "audit_log", + Attributes: map[string]any{ + "message": "test message", + }, + }, + expectedError: "level is required", + }, + "should return error when message is missing": { + obligation: ro.Obligation{ + ID: "audit_log", + Attributes: map[string]any{ + "level": "ERROR", + }, + }, + expectedError: "message is required", + }, + "should return error for invalid attributes type": { + obligation: ro.Obligation{ + ID: "audit_log", + Attributes: map[string]any{ + "level": 123, + "message": 456, + }, + }, + expectedError: "failed to unmarshal attributes", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Initialize handler and test logger + testLogger := &testLogHandler{} + handler := NewAuditLogHandler(slog.New(testLogger)) + + // Execute + err := handler.Handle( + context.Background(), + tc.obligation, + httptest.NewRecorder(), + httptest.NewRequest(http.MethodGet, "/test", http.NoBody), + ) + + // Assert + if tc.expectedError != "" { + assert.ErrorContains(t, err, tc.expectedError) + assert.Empty(t, testLogger.messages) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedMessage, testLogger.messages[0]) + assert.Equal(t, tc.expectedLogLevel, testLogger.levels[0]) + } + + testLogger.reset() + }) + } +} diff --git a/examples/abac/internal/policyresolver/default_resolver.go b/examples/abac/internal/policyresolver/default_resolver.go new file mode 100644 index 0000000..ae16dca --- /dev/null +++ b/examples/abac/internal/policyresolver/default_resolver.go @@ -0,0 +1,33 @@ +package policyresolver + +import ( + "context" + "errors" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" +) + +type defaultResolver struct { + policyID string + policyVersion string +} + +// NewDefaultResolver creates a new instance of defaultResolver with the specified policy ID and version. +func NewDefaultResolver(policyID, policyVersion string) decisionmaker.PolicyResolver { + return &defaultResolver{ + policyID: policyID, + policyVersion: policyVersion, + } +} + +// Resolve resolves policy references by returning the configured default policy. +func (r *defaultResolver) Resolve(_ context.Context, req *decisionmaker.DecisionRequest) ([]decisionmaker.PolicyIdReference, error) { + if req == nil { + return nil, errors.New("decision request cannot be nil") + } + + return []decisionmaker.PolicyIdReference{{ + ID: r.policyID, + Version: r.policyVersion, + }}, nil +} diff --git a/examples/abac/internal/policyresolver/default_resolver_test.go b/examples/abac/internal/policyresolver/default_resolver_test.go new file mode 100644 index 0000000..72045ed --- /dev/null +++ b/examples/abac/internal/policyresolver/default_resolver_test.go @@ -0,0 +1,106 @@ +package policyresolver + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" +) + +func TestDefaultProcessor_Process(t *testing.T) { + defaultPolicyID := "default-deny-policy" + defaultPolicyVersion := "v1" + + testCases := map[string]struct { + req *decisionmaker.DecisionRequest + setupContext func() context.Context + expectedResult []decisionmaker.PolicyIdReference + expectedError string + }{ + "should return default policy for basic request": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user123", + Attributes: map[string]any{ + "role": "manager", + }, + }, + Action: decisionmaker.Action{ + ID: "read", + }, + Resource: decisionmaker.Resource{ + ID: "document1", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: defaultPolicyID, + Version: defaultPolicyVersion, + }, + }, + }, + + "should return error when request is nil": { + req: nil, + setupContext: func() context.Context { return context.Background() }, + expectedResult: nil, + expectedError: "decision request cannot be nil", + }, + + "should return default policy even with cancelled context": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user123", + Attributes: map[string]any{ + "role": "admin", + }, + }, + Resource: decisionmaker.Resource{ + ID: "system", + }, + Action: decisionmaker.Action{ + ID: "manage", + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: defaultPolicyID, + Version: defaultPolicyVersion, + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + resolver := NewDefaultResolver(defaultPolicyID, defaultPolicyVersion) + + // Setup context + ctx := tc.setupContext() + + // Execute + result, err := resolver.Resolve(ctx, tc.req) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + }) + } +} diff --git a/examples/abac/internal/policyresolver/rbac_resolver.go b/examples/abac/internal/policyresolver/rbac_resolver.go new file mode 100644 index 0000000..b465a52 --- /dev/null +++ b/examples/abac/internal/policyresolver/rbac_resolver.go @@ -0,0 +1,45 @@ +package policyresolver + +import ( + "context" + "errors" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" +) + +type rbacResolver struct { + policyID string + policyVersion string +} + +// NewRBACResolver creates and returns a new instance of a PolicyResolver for handling RBAC policy resolution. +func NewRBACResolver(policyID, policyVersion string) decisionmaker.PolicyResolver { + return &rbacResolver{ + policyID: policyID, + policyVersion: policyVersion, + } +} + +// Resolve checks if the subject has a "role" attribute and returns RBAC policy reference if found. +func (r *rbacResolver) Resolve(_ context.Context, req *decisionmaker.DecisionRequest) ([]decisionmaker.PolicyIdReference, error) { + if req == nil { + return nil, errors.New("decision request cannot be nil") + } + + policyIdRefs := make([]decisionmaker.PolicyIdReference, 0) + + // Check if subject has role attribute + if req.Subject.Attributes == nil { + return policyIdRefs, nil + } + + if _, hasRoles := req.Subject.Attributes["roles"]; !hasRoles { + return policyIdRefs, nil + } + + // Return RBAC policy reference + return []decisionmaker.PolicyIdReference{{ + ID: r.policyID, + Version: r.policyVersion, + }}, nil +} diff --git a/examples/abac/internal/policyresolver/rbac_resolver_test.go b/examples/abac/internal/policyresolver/rbac_resolver_test.go new file mode 100644 index 0000000..6db300c --- /dev/null +++ b/examples/abac/internal/policyresolver/rbac_resolver_test.go @@ -0,0 +1,239 @@ +package policyresolver + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" +) + +func TestRoleProcessor_Process(t *testing.T) { + rbacPolicyID := "rbac-policy" + rbacPolicyVersion := "v1" + + testCases := map[string]struct { + req *decisionmaker.DecisionRequest + setupContext func() context.Context + expectedResult []decisionmaker.PolicyIdReference + expectedError string + }{ + "should return RBAC policy when subject has role attribute": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user123", + Attributes: map[string]any{ + "roles": "manager", + }, + }, + Action: decisionmaker.Action{ + ID: "read", + }, + Resource: decisionmaker.Resource{ + ID: "document1", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: rbacPolicyID, + Version: rbacPolicyVersion, + }, + }, + }, + + "should return RBAC policy when subject has role attribute with other attributes": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user456", + Attributes: map[string]any{ + "roles": "employee", + "department": "engineering", + "level": 5, + }, + }, + Resource: decisionmaker.Resource{ + ID: "system", + }, + Action: decisionmaker.Action{ + ID: "write", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: rbacPolicyID, + Version: rbacPolicyVersion, + }, + }, + }, + + "should return empty slice when subject has no role attribute": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user789", + Attributes: map[string]any{ + "department": "sales", + "level": 3, + }, + }, + Action: decisionmaker.Action{ + ID: "read", + }, + Resource: decisionmaker.Resource{ + ID: "report1", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{}, + }, + + "should return empty slice when subject has no attributes": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user999", + }, + Resource: decisionmaker.Resource{ + ID: "document2", + }, + Action: decisionmaker.Action{ + ID: "delete", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{}, + }, + + "should return empty slice when subject attributes is nil": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user000", + Attributes: nil, + }, + Resource: decisionmaker.Resource{ + ID: "resource1", + }, + Action: decisionmaker.Action{ + ID: "execute", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{}, + }, + + "should return empty slice when request is nil": { + req: nil, + setupContext: func() context.Context { return context.Background() }, + expectedResult: nil, + expectedError: "decision request cannot be nil", + }, + + "should handle context cancellation": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user123", + Attributes: map[string]any{ + "roles": "admin", + }, + }, + Resource: decisionmaker.Resource{ + ID: "system", + }, + Action: decisionmaker.Action{ + ID: "manage", + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: rbacPolicyID, + Version: rbacPolicyVersion, + }, + }, + }, + + "should return RBAC policy when role attribute is empty string": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user456", + Attributes: map[string]any{ + "roles": "", + }, + }, + Resource: decisionmaker.Resource{ + ID: "document3", + }, + Action: decisionmaker.Action{ + ID: "read", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: rbacPolicyID, + Version: rbacPolicyVersion, + }, + }, + }, + + "should return RBAC policy when role attribute is nil": { + req: &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: "user789", + Attributes: map[string]any{ + "roles": nil, + }, + }, + Resource: decisionmaker.Resource{ + ID: "file1", + }, + Action: decisionmaker.Action{ + ID: "upload", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: []decisionmaker.PolicyIdReference{ + { + ID: rbacPolicyID, + Version: rbacPolicyVersion, + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + resolver := NewRBACResolver(rbacPolicyID, rbacPolicyVersion) + + // Setup context + ctx := tc.setupContext() + + // Execute + result, err := resolver.Resolve(ctx, tc.req) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + }) + } +} diff --git a/examples/abac/internal/repository/errors.go b/examples/abac/internal/repository/errors.go new file mode 100644 index 0000000..5468c54 --- /dev/null +++ b/examples/abac/internal/repository/errors.go @@ -0,0 +1,17 @@ +package repository + +import ( + "fmt" +) + +// NotFoundError represents an error when a resource is not found +type NotFoundError struct { + Resource string + Key string + Value string +} + +// Error implements the error interface +func (e *NotFoundError) Error() string { + return fmt.Sprintf("%s with %s %s not found", e.Resource, e.Key, e.Value) +} diff --git a/examples/abac/internal/repository/errors_test.go b/examples/abac/internal/repository/errors_test.go new file mode 100644 index 0000000..9faba79 --- /dev/null +++ b/examples/abac/internal/repository/errors_test.go @@ -0,0 +1,30 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNotFoundError_Error(t *testing.T) { + testCases := map[string]struct { + err *NotFoundError + expected string + }{ + "should format error message with all fields": { + err: &NotFoundError{ + Resource: "user", + Key: "email", + Value: "john.doe@example.com", + }, + expected: "user with email john.doe@example.com not found", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := tc.err.Error() + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/examples/abac/internal/repository/postgres/order_repository.go b/examples/abac/internal/repository/postgres/order_repository.go new file mode 100644 index 0000000..fd63829 --- /dev/null +++ b/examples/abac/internal/repository/postgres/order_repository.go @@ -0,0 +1,92 @@ +package postgres + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/domain" +) + +const ( + OrderResource = "order" +) + +// OrderRepository provides database operations for orders +type OrderRepository struct { + pool *pgxpool.Pool +} + +// NewOrderRepository creates a new OrderRepository instance +func NewOrderRepository(pool *pgxpool.Pool) *OrderRepository { + return &OrderRepository{ + pool: pool, + } +} + +// CreateOrder creates a new order in the database +func (r *OrderRepository) CreateOrder(ctx context.Context, order *domain.Order) error { + query := "INSERT INTO orders (id, name, attributes) VALUES ($1, $2, $3)" + + _, err := r.pool.Exec(ctx, query, order.ID, order.Name, order.Attributes) + if err != nil { + return fmt.Errorf("failed to create order: %w", err) + } + + return nil +} + +// GetOrderByID retrieves an order by its ID from the database +func (r *OrderRepository) GetOrderByID(ctx context.Context, id uuid.UUID) (*domain.Order, error) { + var order domain.Order + query := "SELECT id, name, attributes FROM orders WHERE id = $1" + + err := r.pool.QueryRow(ctx, query, id).Scan(&order.ID, &order.Name, &order.Attributes) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, &repository.NotFoundError{ + Resource: OrderResource, + Key: "id", + Value: id.String(), + } + } + return nil, fmt.Errorf("failed to retrieve order with id %s: %w", id, err) + } + + return &order, nil +} + +// GetOrderAttributesByID retrieves order attributes by order ID. +// Returns the attributes as a map for use by info providers. +func (r *OrderRepository) GetOrderAttributesByID(ctx context.Context, id uuid.UUID) (map[string]any, error) { + var attributesData []byte + query := "SELECT attributes FROM orders WHERE id = $1" + + err := r.pool.QueryRow(ctx, query, id).Scan(&attributesData) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, &repository.NotFoundError{ + Resource: OrderResource, + Key: "id", + Value: id.String(), + } + } + return nil, fmt.Errorf("query attributes for order %s: %w", id, err) + } + + // Decode JSON attributes + attrs := make(map[string]any) + if len(attributesData) > 0 { + if err := json.Unmarshal(attributesData, &attrs); err != nil { + return nil, fmt.Errorf("decode attributes for order %s: %w", id, err) + } + } + + return attrs, nil +} diff --git a/examples/abac/internal/repository/postgres/order_repository_test.go b/examples/abac/internal/repository/postgres/order_repository_test.go new file mode 100644 index 0000000..bcbf6be --- /dev/null +++ b/examples/abac/internal/repository/postgres/order_repository_test.go @@ -0,0 +1,570 @@ +//nolint:dupl // unit tests +package postgres + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/domain" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testOrder struct { + id uuid.UUID + name string + attributes map[string]any +} + +func TestOrderRepository_CreateOrder(t *testing.T) { + pool := setupTestDBForOrders(t) + defer pool.Close() + + orderID := uuid.New() + testCases := map[string]struct { + order *domain.Order + setupContext func() context.Context + expectedError string + verifyInDB bool + }{ + + "should create order with empty attributes": { + order: &domain.Order{ + ID: orderID, + Name: "Basic Order", + Attributes: map[string]any{}, + }, + setupContext: func() context.Context { return context.Background() }, + verifyInDB: true, + }, + + "should create order with complex nested attributes": { + order: &domain.Order{ + ID: orderID, + Name: "Enterprise Solution", + Attributes: map[string]any{ + "category": "enterprise", + "price": 1999.50, + "currency": "EUR", + "status": "pending", + "quantity": float64(5), + "discount_rate": 0.15, + "metadata": map[string]any{ + "source": "web", + "campaign": "summer2024", + "referrer": "partner_site", + "custom_fields": map[string]any{ + "priority": "high", + "rush_order": true, + "special_notes": "Handle with care", + }, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + verifyInDB: true, + }, + + "should return error when context is cancelled": { + order: &domain.Order{ + ID: orderID, + Name: "Context Test Order", + Attributes: map[string]any{ + "category": "test", + "price": 49.99, + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedError: "context canceled", + verifyInDB: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + repo := NewOrderRepository(pool) + + // Setup context + ctx := tc.setupContext() + + // Execute + err := repo.CreateOrder(ctx, tc.order) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + } else { + assert.NoError(t, err) + } + + // Verify in database if expected + if tc.verifyInDB && err == nil { + retrievedOrder, err := repo.GetOrderByID(context.Background(), tc.order.ID) + assert.NoError(t, err) + assert.EqualValues(t, tc.order, retrievedOrder) + } + + // Clean up test data + cleanupTestOrdersData(t, pool) + }) + } +} + +func TestOrderRepository_GetOrderByID(t *testing.T) { + pool := setupTestDBForOrders(t) + defer pool.Close() + + orderID := uuid.New() + anotherOrderID := uuid.New() + nonExistentID := uuid.New() + + testCases := map[string]struct { + id uuid.UUID + testOrders []testOrder + expectedResult *domain.Order + expectedError string + expectNotFoundErr bool + }{ + "should return order with complex attributes": { + id: orderID, + testOrders: []testOrder{ + { + id: orderID, + name: "Enterprise Solution", + attributes: map[string]any{ + "category": "enterprise", + "price": 1999.50, + "currency": "EUR", + "status": "pending", + "quantity": 5, + "discount_rate": 0.15, + "metadata": map[string]any{ + "source": "web", + "campaign": "summer2024", + "referrer": "partner_site", + "custom_fields": map[string]any{ + "priority": "high", + "rush_order": true, + "special_notes": "Handle with care", + }, + }, + }, + }, + }, + expectedResult: &domain.Order{ + ID: orderID, + Name: "Enterprise Solution", + Attributes: map[string]any{ + "category": "enterprise", + "price": 1999.50, + "currency": "EUR", + "status": "pending", + "quantity": float64(5), + "discount_rate": 0.15, + "metadata": map[string]any{ + "source": "web", + "campaign": "summer2024", + "referrer": "partner_site", + "custom_fields": map[string]any{ + "priority": "high", + "rush_order": true, + "special_notes": "Handle with care", + }, + }, + }, + }, + }, + + "should return order with empty attributes": { + id: orderID, + testOrders: []testOrder{ + { + id: orderID, + name: "Basic Order", + }, + }, + expectedResult: &domain.Order{ + ID: orderID, + Name: "Basic Order", + Attributes: map[string]any{}, + }, + }, + + "should return NotFoundError when order does not exist": { + id: nonExistentID, + testOrders: []testOrder{ + { + id: orderID, + name: "Existing Order", + attributes: map[string]any{ + "category": "standard", + "price": 99.99, + }, + }, + }, + expectedResult: nil, + expectedError: fmt.Sprintf("order with id %s not found", nonExistentID.String()), + expectNotFoundErr: true, + }, + + "should handle multiple orders but return correct one": { + id: orderID, + testOrders: []testOrder{ + { + id: anotherOrderID, + name: "Other Order", + attributes: map[string]any{ + "category": "other", + "price": 199.99, + }, + }, + { + id: orderID, + name: "Target Order", + attributes: map[string]any{ + "category": "target", + "price": 399.99, + "priority": "high", + }, + }, + }, + expectedResult: &domain.Order{ + ID: orderID, + Name: "Target Order", + Attributes: map[string]any{ + "category": "target", + "price": 399.99, + "priority": "high", + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + repo := NewOrderRepository(pool) + + // Setup test data + setupTestOrdersData(t, pool, tc.testOrders) + + // Execute + result, err := repo.GetOrderByID(context.Background(), tc.id) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + + // Test custom NotFoundError type + if tc.expectNotFoundErr { + var notFoundErr *repository.NotFoundError + assert.True(t, errors.As(err, ¬FoundErr)) + assert.Equal(t, OrderResource, notFoundErr.Resource) + assert.Equal(t, "id", notFoundErr.Key) + assert.Equal(t, tc.id.String(), notFoundErr.Value) + + // Test errors.Is functionality + var notFoundError *repository.NotFoundError + assert.True(t, errors.As(err, ¬FoundError)) + } + } else { + assert.NoError(t, err) + assert.EqualValues(t, tc.expectedResult, result) + } + + // Clean up test data + cleanupTestOrdersData(t, pool) + }) + } +} + +func TestOrderRepository_GetOrderAttributesByID(t *testing.T) { + pool := setupTestDBForOrders(t) + defer pool.Close() + + orderID := uuid.New() + anotherOrderID := uuid.New() + nonExistentID := uuid.New() + + testCases := map[string]struct { + id uuid.UUID + testOrders []testOrder + setupContext func() context.Context + expectedResult map[string]any + expectedError string + expectNotFoundErr bool + }{ + "should return attributes when order exists with ID": { + id: orderID, + testOrders: []testOrder{ + { + id: orderID, + name: "Complex Order", + attributes: map[string]any{ + "category": "premium", + "price": 2499.99, + "currency": "USD", + "total_items": 15, + "customer": map[string]any{ + "id": "cust_123", + "name": "John Doe", + "tier": "gold", + "contacts": []string{"email", "sms", "push"}, + }, + "metadata": map[string]any{ + "source": "mobile_app", + "campaign": "holiday2024", + "processing": true, + "tags": []string{"urgent", "vip", "express"}, + "analytics": map[string]any{ + "conversion_rate": 0.85, + "session_id": "sess_789", + "utm_source": "google", + }, + }, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "category": "premium", + "price": 2499.99, + "currency": "USD", + "total_items": float64(15), + "customer": map[string]any{ + "id": "cust_123", + "name": "John Doe", + "tier": "gold", + "contacts": []any{"email", "sms", "push"}, + }, + "metadata": map[string]any{ + "source": "mobile_app", + "campaign": "holiday2024", + "processing": true, + "tags": []any{"urgent", "vip", "express"}, + "analytics": map[string]any{ + "conversion_rate": 0.85, + "session_id": "sess_789", + "utm_source": "google", + }, + }, + }, + }, + + "should return empty map when order has no attributes": { + id: orderID, + testOrders: []testOrder{ + { + id: orderID, + name: "Basic Order", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{}, + }, + + "should handle mixed data types correctly": { + id: orderID, + testOrders: []testOrder{ + { + id: orderID, + name: "Mixed Data Order", + attributes: map[string]any{ + "string_field": "hello world", + "number_field": 42, + "float_field": 3.14159, + "bool_field": true, + "null_field": nil, + "array_numbers": []int{1, 2, 3, 4, 5}, + "array_mixed": []any{"text", 123, false, 99.9}, + "nested": map[string]any{ + "level1": map[string]any{ + "level2": "deep value", + "array": []string{"a", "b", "c"}, + }, + }, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "string_field": "hello world", + "number_field": float64(42), + "float_field": 3.14159, + "bool_field": true, + "null_field": nil, + "array_numbers": []any{float64(1), float64(2), float64(3), float64(4), float64(5)}, + "array_mixed": []any{"text", float64(123), false, 99.9}, + "nested": map[string]any{ + "level1": map[string]any{ + "level2": "deep value", + "array": []any{"a", "b", "c"}, + }, + }, + }, + }, + + "should return NotFoundError when order does not exist": { + id: nonExistentID, + testOrders: []testOrder{ + { + id: orderID, + name: "Existing Order", + attributes: map[string]any{ + "category": "standard", + "price": 99.99, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: nil, + expectedError: fmt.Sprintf("order with id %s not found", nonExistentID.String()), + expectNotFoundErr: true, + }, + + "should return error when context is cancelled": { + id: orderID, + testOrders: []testOrder{ + { + id: orderID, + name: "Context Test Order", + attributes: map[string]any{ + "category": "test", + "price": 49.99, + }, + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedResult: nil, + expectedError: "context canceled", + }, + + "should handle multiple orders but return correct attributes": { + id: orderID, + testOrders: []testOrder{ + { + id: anotherOrderID, + name: "Other Order", + attributes: map[string]any{ + "category": "other", + "price": 199.99, + }, + }, + { + id: orderID, + name: "Target Order", + attributes: map[string]any{ + "category": "target", + "price": 399.99, + "priority": "high", + "user_id": "target_user", + "processing": true, + "tags": []string{"express", "priority"}, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "category": "target", + "price": 399.99, + "priority": "high", + "user_id": "target_user", + "processing": true, + "tags": []any{"express", "priority"}, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + repo := NewOrderRepository(pool) + + // Setup test data + setupTestOrdersData(t, pool, tc.testOrders) + + // Setup context + ctx := tc.setupContext() + + // Execute + result, err := repo.GetOrderAttributesByID(ctx, tc.id) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + + // Test custom NotFoundError type + if tc.expectNotFoundErr { + var notFoundErr *repository.NotFoundError + assert.True(t, errors.As(err, ¬FoundErr)) + assert.Equal(t, OrderResource, notFoundErr.Resource) + assert.Equal(t, "id", notFoundErr.Key) + assert.Equal(t, tc.id.String(), notFoundErr.Value) + } + } else { + assert.NoError(t, err) + assert.EqualValues(t, tc.expectedResult, result) + } + + // Clean up test data + cleanupTestOrdersData(t, pool) + }) + } +} + +func setupTestDBForOrders(t *testing.T) *pgxpool.Pool { + pg := fmt.Sprintf( + "postgres://%s:%s@%s/%s?sslmode=%s", + os.Getenv("POSTGRES_USER"), + os.Getenv("POSTGRES_PASSWORD"), + os.Getenv("POSTGRES_HOST"), + os.Getenv("POSTGRES_DB_TEST"), + os.Getenv("POSTGRES_SSL"), + ) + + pool, err := pgxpool.New(context.Background(), pg) + require.NoError(t, err) + return pool +} + +func setupTestOrdersData(t *testing.T, pool *pgxpool.Pool, orders []testOrder) { + for _, order := range orders { + if order.attributes == nil { + _, err := pool.Exec( + context.Background(), + "INSERT INTO orders (id, name) VALUES ($1, $2)", + order.id, order.name, + ) + require.NoError(t, err) + continue + } + + _, err := pool.Exec( + context.Background(), + "INSERT INTO orders (id, name, attributes) VALUES ($1, $2, $3)", + order.id, order.name, order.attributes, + ) + require.NoError(t, err) + } +} + +func cleanupTestOrdersData(t *testing.T, pool *pgxpool.Pool) { + _, err := pool.Exec(context.Background(), "TRUNCATE TABLE orders") + require.NoError(t, err) +} diff --git a/examples/abac/internal/repository/postgres/rbac_repository.go b/examples/abac/internal/repository/postgres/rbac_repository.go new file mode 100644 index 0000000..d177159 --- /dev/null +++ b/examples/abac/internal/repository/postgres/rbac_repository.go @@ -0,0 +1,152 @@ +package postgres + +import ( + "context" + "fmt" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// RBACRepository provides Postgres-backed RBAC reads. +type RBACRepository struct { + pool *pgxpool.Pool +} + +// NewRBACRepository constructs a Postgres RBAC repository. +func NewRBACRepository(pool *pgxpool.Pool) *RBACRepository { + return &RBACRepository{pool: pool} +} + +// GetRoleDescendants retrieves all descendant roles (including the roots) via a recursive CTE. +func (r *RBACRepository) GetRoleDescendants(ctx context.Context, rootRoles []string) ([]string, error) { + if len(rootRoles) == 0 { + return nil, fmt.Errorf("rootRoles cannot be empty") + } + + const query = ` +WITH RECURSIVE role_descendants AS ( + -- Start with all root roles provided + SELECT r.id, + r.name, + 0 AS level + FROM roles r + WHERE r.name = ANY($1::text[]) + + UNION ALL + + -- Add direct children + SELECT child_role.id, + child_role.name, + rd.level + 1 + FROM role_hierarchy rh + INNER JOIN roles child_role ON rh.child_role_id = child_role.id + INNER JOIN role_descendants rd ON rh.parent_role_id = rd.id +) +SELECT DISTINCT name +FROM role_descendants +ORDER BY name; +` + rows, err := r.pool.Query(ctx, query, rootRoles) + if err != nil { + return nil, fmt.Errorf("query role descendants for %v: %w", rootRoles, err) + } + defer rows.Close() + + var out []string + var roleName string + _, scanErr := pgx.ForEachRow(rows, []any{&roleName}, func() error { + out = append(out, roleName) + return nil + }) + if scanErr != nil { + return nil, fmt.Errorf("scan role descendants for %v: %w", rootRoles, scanErr) + } + return out, nil +} + +// GetPermissionsByRoles retrieves permissions grouped by role for the given role names. +func (r *RBACRepository) GetPermissionsByRoles(ctx context.Context, roles []string) (map[string][]infoprovider.Permission, error) { + if len(roles) == 0 { + return map[string][]infoprovider.Permission{}, nil + } + + const query = ` +SELECT + r.name AS role_name, + a.name AS action_name, + res.name AS resource_name, + rp.id AS permission_id, + rpc.attribute_key, + rpc.operator, + rpc.attribute_value +FROM role_permissions rp + INNER JOIN roles r ON rp.role_id = r.id + INNER JOIN actions a ON rp.action_id = a.id + INNER JOIN resources res ON rp.resource_id = res.id + LEFT JOIN role_permission_conditions rpc ON rp.id = rpc.permission_id +WHERE r.name = ANY($1) +ORDER BY r.name, a.name, res.name, rpc.attribute_key +` + rows, err := r.pool.Query(ctx, query, roles) + if err != nil { + return nil, fmt.Errorf("query permissions for roles %v: %w", roles, err) + } + defer rows.Close() + + return processPermissionRows(rows, roles) +} + +// processPermissionRows converts DB rows into a role->[]Permission map. +func processPermissionRows(rows pgx.Rows, roles []string) (map[string][]infoprovider.Permission, error) { + var roleName, actionName, resourceName, permissionID string + var attributeKey, operator *string + var attributeValue any + + type agg struct { + role string + perm infoprovider.Permission + } + + byID := make(map[string]*agg) + + _, err := pgx.ForEachRow( + rows, + []any{&roleName, &actionName, &resourceName, &permissionID, &attributeKey, &operator, &attributeValue}, + func() error { + entry, ok := byID[permissionID] + if !ok { + entry = &agg{ + role: roleName, + perm: infoprovider.Permission{ + ActionName: actionName, + ResourceName: resourceName, + Conditions: make([]infoprovider.PermissionCondition, 0), + }, + } + byID[permissionID] = entry + } + + // Append condition row if present + if attributeKey != nil && operator != nil { + entry.perm.Conditions = append(entry.perm.Conditions, infoprovider.PermissionCondition{ + AttributeKey: *attributeKey, + Operator: *operator, + AttributeValue: attributeValue, + }) + } + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("process permissions for roles %v: %w", roles, err) + } + + // Group by role + out := make(map[string][]infoprovider.Permission) + for _, entry := range byID { + out[entry.role] = append(out[entry.role], entry.perm) + } + return out, nil +} diff --git a/examples/abac/internal/repository/postgres/rbac_repository_test.go b/examples/abac/internal/repository/postgres/rbac_repository_test.go new file mode 100644 index 0000000..8b7af9e --- /dev/null +++ b/examples/abac/internal/repository/postgres/rbac_repository_test.go @@ -0,0 +1,371 @@ +package postgres + +import ( + "context" + "encoding/json" + "fmt" + "os" + "testing" + + ip "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testRole struct { + name string +} + +type testAction struct { + name string +} + +type testResource struct { + name string +} + +type testRoleHierarchy struct { + parentRole string + childRole string +} + +type testRolePermission struct { + roleName string + actionName string + resourceName string + conditions []testPermissionCondition +} + +type testPermissionCondition struct { + attributeKey string + operator string + attributeValue any +} + +func TestRBACRepository_GetRoleDescendants(t *testing.T) { + pool := setupTestDBForRBACRepo(t) + defer pool.Close() + repo := NewRBACRepository(pool) + + testCases := map[string]struct { + rootRoles []string + testRoles []testRole + testRoleHierarchy []testRoleHierarchy + setupCtx func() context.Context + expectedDescendants []string + expectedErrSubstr string + }{ + "should return descendants for single role with child": { + rootRoles: []string{"manager"}, + testRoles: []testRole{ + {name: "manager"}, + {name: "employee"}, + }, + testRoleHierarchy: []testRoleHierarchy{ + {parentRole: "manager", childRole: "employee"}, + }, + setupCtx: func() context.Context { return context.Background() }, + expectedDescendants: []string{"employee", "manager"}, + }, + "should return deduped descendants for multiple roots and deep tree": { + rootRoles: []string{"admin", "lead"}, + testRoles: []testRole{ + {name: "admin"}, + {name: "lead"}, + {name: "engineer"}, + }, + testRoleHierarchy: []testRoleHierarchy{ + {parentRole: "admin", childRole: "lead"}, + {parentRole: "lead", childRole: "engineer"}, + }, + setupCtx: func() context.Context { return context.Background() }, + expectedDescendants: []string{"admin", "engineer", "lead"}, + }, + "should return empty when role does not exist": { + rootRoles: []string{"ghost"}, + testRoles: []testRole{}, + testRoleHierarchy: []testRoleHierarchy{}, + setupCtx: func() context.Context { return context.Background() }, + expectedDescendants: []string{}, + }, + "should return error when root roles is empty": { + rootRoles: []string{}, + testRoles: []testRole{}, + testRoleHierarchy: []testRoleHierarchy{}, + setupCtx: func() context.Context { return context.Background() }, + expectedErrSubstr: "rootRoles cannot be empty", + }, + "should return error when context is cancelled": { + rootRoles: []string{"manager"}, + testRoles: []testRole{ + {name: "manager"}, + }, + testRoleHierarchy: []testRoleHierarchy{}, + setupCtx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedErrSubstr: "query role descendants", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + setupTestRBACRepoData(t, pool, tc.testRoles, nil, nil, tc.testRoleHierarchy, nil) + + ctx := tc.setupCtx() + + got, err := repo.GetRoleDescendants(ctx, tc.rootRoles) + + if tc.expectedErrSubstr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrSubstr) + assert.Nil(t, got) + } else { + require.NoError(t, err) + assert.ElementsMatch(t, tc.expectedDescendants, got) + } + + cleanupTestRBACRepoData(t, pool) + }) + } +} + +func TestRBACRepository_GetPermissionsByRoles(t *testing.T) { + pool := setupTestDBForRBACRepo(t) + defer pool.Close() + repo := NewRBACRepository(pool) + + testCases := map[string]struct { + roles []string + testRoles []testRole + testActions []testAction + testResources []testResource + testPermissions []testRolePermission + setupCtx func() context.Context + expected map[string][]ip.Permission + expectedErrSubstr string + }{ + "should return permissions with and without conditions": { + roles: []string{"manager", "employee"}, + testRoles: []testRole{ + {name: "manager"}, + {name: "employee"}, + }, + testActions: []testAction{ + {name: "read"}, + {name: "write"}, + }, + testResources: []testResource{ + {name: "document"}, + {name: "report"}, + }, + testPermissions: []testRolePermission{ + { + roleName: "manager", + actionName: "read", + resourceName: "document", + conditions: []testPermissionCondition{ + {attributeKey: "department", operator: "eq", attributeValue: "sales"}, + }, + }, + { + roleName: "manager", + actionName: "write", + resourceName: "document", + conditions: []testPermissionCondition{ + {attributeKey: "level", operator: "gte", attributeValue: float64(5)}, + }, + }, + { + roleName: "employee", + actionName: "read", + resourceName: "report", + conditions: []testPermissionCondition{}, + }, + }, + setupCtx: func() context.Context { return context.Background() }, + expected: map[string][]ip.Permission{ + "manager": { + { + ActionName: "read", + ResourceName: "document", + Conditions: []ip.PermissionCondition{ + {AttributeKey: "department", Operator: "eq", AttributeValue: "sales"}, + }, + }, + { + ActionName: "write", + ResourceName: "document", + Conditions: []ip.PermissionCondition{ + {AttributeKey: "level", Operator: "gte", AttributeValue: float64(5)}, + }, + }, + }, + "employee": { + { + ActionName: "read", + ResourceName: "report", + Conditions: []ip.PermissionCondition{}, + }, + }, + }, + }, + "should return empty when role has no permissions": { + roles: []string{"viewer"}, + testRoles: []testRole{ + {name: "viewer"}, + }, + setupCtx: func() context.Context { return context.Background() }, + expected: map[string][]ip.Permission{}, + }, + "should return empty when input roles is empty": { + roles: []string{}, + testRoles: []testRole{}, + setupCtx: func() context.Context { return context.Background() }, + expected: map[string][]ip.Permission{}, + }, + "should return error when context is cancelled": { + roles: []string{"manager"}, + testRoles: []testRole{ + {name: "manager"}, + }, + testActions: []testAction{{name: "read"}}, + testResources: []testResource{{name: "doc"}}, + testPermissions: []testRolePermission{ + {roleName: "manager", actionName: "read", resourceName: "doc"}, + }, + setupCtx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedErrSubstr: "query permissions for roles", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + setupTestRBACRepoData(t, pool, tc.testRoles, tc.testActions, tc.testResources, nil, tc.testPermissions) + + ctx := tc.setupCtx() + got, err := repo.GetPermissionsByRoles(ctx, tc.roles) + + if tc.expectedErrSubstr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrSubstr) + assert.Nil(t, got) + } else { + require.NoError(t, err) + // Compare role keys and permissions (order-independent) + assert.Equal(t, len(tc.expected), len(got)) + for role, expectedPerms := range tc.expected { + actual := got[role] + assert.ElementsMatch(t, expectedPerms, actual, "permissions for role %s should match", role) + } + } + + cleanupTestRBACRepoData(t, pool) + }) + } +} + +func setupTestDBForRBACRepo(t *testing.T) *pgxpool.Pool { + pg := fmt.Sprintf( + "postgres://%s:%s@%s/%s?sslmode=%s", + os.Getenv("POSTGRES_USER"), + os.Getenv("POSTGRES_PASSWORD"), + os.Getenv("POSTGRES_HOST"), + os.Getenv("POSTGRES_DB_TEST"), + os.Getenv("POSTGRES_SSL"), + ) + + pool, err := pgxpool.New(context.Background(), pg) + require.NoError(t, err) + + return pool +} + +func setupTestRBACRepoData( + t *testing.T, + pool *pgxpool.Pool, + roles []testRole, + actions []testAction, + resources []testResource, + hierarchies []testRoleHierarchy, + permissions []testRolePermission, +) { + // Insert roles + for _, role := range roles { + _, err := pool.Exec(context.Background(), "INSERT INTO roles (id, name) VALUES ($1, $2)", uuid.New(), role.name) + require.NoError(t, err) + } + + // Insert actions + for _, action := range actions { + _, err := pool.Exec(context.Background(), "INSERT INTO actions (id, name) VALUES ($1, $2)", uuid.New(), action.name) + require.NoError(t, err) + } + + // Insert resources + for _, resource := range resources { + _, err := pool.Exec(context.Background(), "INSERT INTO resources (id, name) VALUES ($1, $2)", uuid.New(), resource.name) + require.NoError(t, err) + } + + // Insert role hierarchy + for _, hierarchy := range hierarchies { + _, err := pool.Exec(context.Background(), ` + INSERT INTO role_hierarchy (id, parent_role_id, child_role_id) + SELECT $1, pr.id, cr.id + FROM roles pr, roles cr + WHERE pr.name = $2 AND cr.name = $3 + `, uuid.New(), hierarchy.parentRole, hierarchy.childRole) + require.NoError(t, err) + } + + // Insert role permissions and their conditions + for _, permission := range permissions { + permissionID := uuid.New() + + // Insert permission + _, err := pool.Exec(context.Background(), ` + INSERT INTO role_permissions (id, role_id, action_id, resource_id) + SELECT $1, r.id, a.id, res.id + FROM roles r, actions a, resources res + WHERE r.name = $2 AND a.name = $3 AND res.name = $4 + `, permissionID, permission.roleName, permission.actionName, permission.resourceName) + require.NoError(t, err) + + // Insert permission conditions + for _, condition := range permission.conditions { + jsonValue, err := json.Marshal(condition.attributeValue) + require.NoError(t, err) + + _, err = pool.Exec(context.Background(), ` + INSERT INTO role_permission_conditions (permission_id, attribute_key, operator, attribute_value) + VALUES ($1, $2, $3, $4) + `, permissionID, condition.attributeKey, condition.operator, jsonValue) + require.NoError(t, err) + } + } +} + +func cleanupTestRBACRepoData(t *testing.T, pool *pgxpool.Pool) { + tables := []string{ + "role_permission_conditions", + "role_permissions", + "role_hierarchy", + "resources", + "actions", + "roles", + } + + for _, table := range tables { + _, err := pool.Exec(context.Background(), fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table)) + require.NoError(t, err) + } +} diff --git a/examples/abac/internal/repository/postgres/user_repository.go b/examples/abac/internal/repository/postgres/user_repository.go new file mode 100644 index 0000000..876d952 --- /dev/null +++ b/examples/abac/internal/repository/postgres/user_repository.go @@ -0,0 +1,101 @@ +package postgres + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +const ( + UserResource = "user" +) + +type UserRepository struct { + pool *pgxpool.Pool +} + +func NewUserRepository(pool *pgxpool.Pool) *UserRepository { + return &UserRepository{pool: pool} +} + +// GetUserIDByEmail retrieves a user ID by email address. +func (r *UserRepository) GetUserIDByEmail(ctx context.Context, email string) (uuid.UUID, error) { + if email == "" { + return uuid.Nil, fmt.Errorf("email cannot be empty") + } + + const query = `SELECT id FROM users WHERE email = $1` + + var userID uuid.UUID + err := r.pool.QueryRow(ctx, query, email).Scan(&userID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return uuid.Nil, &repository.NotFoundError{ + Resource: UserResource, + Key: "email", + Value: email, + } + } + return uuid.Nil, fmt.Errorf("query user ID by email %s: %w", email, err) + } + + return userID, nil +} + +// GetUserAttributesByID retrieves user attributes by user ID. +// Returns attributes where roles are guaranteed to be []string type. +// This method implements the UserAttributesRepository interface. +func (r *UserRepository) GetUserAttributesByID(ctx context.Context, id uuid.UUID) (map[string]any, error) { + type row struct { + Attributes []byte `db:"attributes"` + Roles []string `db:"roles"` + } + + // Extract roles as string array, handling different JSON formats + const query = ` +SELECT + u.attributes, + CASE jsonb_typeof(u.attributes->'roles') + WHEN 'array' THEN ARRAY(SELECT jsonb_array_elements_text(u.attributes->'roles')) + WHEN 'string' THEN ARRAY[(u.attributes->>'roles')] + ELSE ARRAY[]::text[] + END AS roles +FROM users u +WHERE u.id = $1 +` + rows, err := r.pool.Query(ctx, query, id) + if err != nil { + return nil, fmt.Errorf("query attributes for user %s: %w", id, err) + } + defer rows.Close() + + rec, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[row]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, &repository.NotFoundError{ + Resource: UserResource, + Key: "id", + Value: id.String(), + } + } + return nil, fmt.Errorf("scan attributes for user %s: %w", id, err) + } + + // Decode JSON attributes + attrs := make(map[string]any) + if len(rec.Attributes) > 0 { + if err := json.Unmarshal(rec.Attributes, &attrs); err != nil { + return nil, fmt.Errorf("decode attributes for user %s: %w", id, err) + } + } + // Override roles with string array from database query + attrs["roles"] = rec.Roles + + return attrs, nil +} diff --git a/examples/abac/internal/repository/postgres/user_repository_test.go b/examples/abac/internal/repository/postgres/user_repository_test.go new file mode 100644 index 0000000..1f73cb4 --- /dev/null +++ b/examples/abac/internal/repository/postgres/user_repository_test.go @@ -0,0 +1,548 @@ +//nolint:dupl // unit tests +package postgres + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + + "github.com/CameronXie/access-control-explorer/examples/abac/internal/repository" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testUser struct { + id uuid.UUID + email string + attributes map[string]any +} + +func TestUserRepository_GetUserIDByEmail(t *testing.T) { + pool := setupTestDBForUsers(t) + defer pool.Close() + + userID := uuid.New() + anotherUserID := uuid.New() + nonExistentEmail := "nonexistent@example.com" + + testCases := map[string]struct { + email string + testUsers []testUser + setupContext func() context.Context + expectedResult uuid.UUID + expectedError string + expectNotFoundErr bool + }{ + "should return user ID when user exists with email": { + email: "john.doe@example.com", + testUsers: []testUser{ + { + id: userID, + email: "john.doe@example.com", + attributes: map[string]any{ + "roles": []string{"admin", "user"}, + "department": "engineering", + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: userID, + }, + + "should return correct user ID when user exists": { + email: "jane.smith@example.com", + testUsers: []testUser{ + { + id: userID, + email: "jane.smith@example.com", + attributes: map[string]any{ + "roles": []string{"manager"}, + "department": "sales", + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: userID, + }, + + "should return user ID even with empty attributes": { + email: "empty.user@example.com", + testUsers: []testUser{ + { + id: userID, + email: "empty.user@example.com", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: userID, + }, + + "should return NotFoundError when user does not exist": { + email: nonExistentEmail, + testUsers: []testUser{ + { + id: userID, + email: "existing.user@example.com", + attributes: map[string]any{ + "roles": []string{"user"}, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: uuid.Nil, + expectedError: fmt.Sprintf("user with email %s not found", nonExistentEmail), + expectNotFoundErr: true, + }, + + "should return error when email is empty": { + email: "", + testUsers: []testUser{ + { + id: userID, + email: "test.user@example.com", + attributes: map[string]any{ + "roles": []string{"user"}, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: uuid.Nil, + expectedError: "email cannot be empty", + }, + + "should return error when context is cancelled": { + email: "context.test@example.com", + testUsers: []testUser{ + { + id: userID, + email: "context.test@example.com", + attributes: map[string]any{ + "roles": []string{"user"}, + }, + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedResult: uuid.Nil, + expectedError: "context canceled", + }, + + "should handle multiple users but return correct ID": { + email: "target.user@example.com", + testUsers: []testUser{ + { + id: anotherUserID, + email: "other.user@example.com", + attributes: map[string]any{ + "roles": []string{"guest"}, + }, + }, + { + id: userID, + email: "target.user@example.com", + attributes: map[string]any{ + "roles": []string{"admin"}, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: userID, + }, + + "should return ID regardless of attribute complexity": { + email: "complex.user@example.com", + testUsers: []testUser{ + { + id: userID, + email: "complex.user@example.com", + attributes: map[string]any{ + "roles": []string{"admin", "user", "moderator"}, + "department": "engineering", + "team_members": 15, + "budget_limit": 50000.75, + "preferences": map[string]any{ + "theme": "dark", + "language": "en", + "notifications": map[string]any{ + "email": true, + "sms": false, + }, + }, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: userID, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + repo := NewUserRepository(pool) + + // Setup test data + setupTestUsersData(t, pool, tc.testUsers) + + // Setup context + ctx := tc.setupContext() + + // Execute + result, err := repo.GetUserIDByEmail(ctx, tc.email) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Equal(t, uuid.Nil, result) + + // Test custom NotFoundError type + if tc.expectNotFoundErr { + var notFoundErr *repository.NotFoundError + assert.True(t, errors.As(err, ¬FoundErr)) + assert.Equal(t, UserResource, notFoundErr.Resource) + assert.Equal(t, "email", notFoundErr.Key) + assert.Equal(t, tc.email, notFoundErr.Value) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedResult, result) + } + + // Clean up test data + cleanupTestUsersData(t, pool) + }) + } +} + +func TestUserRepository_GetUserAttributesByID(t *testing.T) { + pool := setupTestDBForUsers(t) + defer pool.Close() + + userID := uuid.New() + anotherUserID := uuid.New() + nonExistentID := uuid.New() + + testCases := map[string]struct { + id uuid.UUID + testUsers []testUser + setupContext func() context.Context + expectedResult map[string]any + expectedError string + expectNotFoundErr bool + }{ + "should return attributes when user exists with ID": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "john.doe@example.com", + attributes: map[string]any{ + "roles": []string{"admin", "user"}, + "department": "engineering", + "region": "north_america", + "level": "senior", + "permissions": map[string]any{ + "read": true, + "write": true, + "admin": true, + }, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{"admin", "user"}, + "department": "engineering", + "region": "north_america", + "level": "senior", + "permissions": map[string]any{ + "read": true, + "write": true, + "admin": true, + }, + }, + }, + + "should return attributes with complex nested data": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "jane.smith@example.com", + attributes: map[string]any{ + "roles": []string{"manager", "reviewer"}, + "department": "sales", + "region": "europe", + "level": "manager", + "team_members": 15, + "budget_limit": 50000.75, + "preferences": map[string]any{ + "theme": "dark", + "language": "en", + "timezone": "UTC+1", + "notifications": map[string]any{ + "email": true, + "sms": false, + "push": true, + }, + }, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{"manager", "reviewer"}, + "department": "sales", + "region": "europe", + "level": "manager", + "team_members": float64(15), + "budget_limit": 50000.75, + "preferences": map[string]any{ + "theme": "dark", + "language": "en", + "timezone": "UTC+1", + "notifications": map[string]any{ + "email": true, + "sms": false, + "push": true, + }, + }, + }, + }, + + "should return empty attributes map when user has no attributes": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "empty.user@example.com", + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{}, + }, + }, + + "should handle roles as single string value": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "single.role@example.com", + attributes: map[string]any{ + "roles": "admin", // Single string instead of array + "department": "it", + "active": true, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{"admin"}, + "department": "it", + "active": true, + }, + }, + + "should handle missing roles field gracefully": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "no.roles@example.com", + attributes: map[string]any{ + "department": "hr", + "level": "junior", + "active": true, + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{}, + "department": "hr", + "level": "junior", + "active": true, + }, + }, + + "should return NotFoundError when user does not exist": { + id: nonExistentID, + testUsers: []testUser{ + { + id: userID, + email: "existing.user@example.com", + attributes: map[string]any{ + "roles": []string{"user"}, + "department": "support", + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: nil, + expectedError: fmt.Sprintf("user with id %s not found", nonExistentID.String()), + expectNotFoundErr: true, + }, + + "should return error when context is cancelled": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "context.test@example.com", + attributes: map[string]any{ + "roles": []string{"user"}, + "department": "testing", + }, + }, + }, + setupContext: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + expectedResult: nil, + expectedError: "context canceled", + }, + + "should handle multiple users but return correct attributes": { + id: userID, + testUsers: []testUser{ + { + id: anotherUserID, + email: "other.user@example.com", + attributes: map[string]any{ + "roles": []string{"guest"}, + "department": "other", + }, + }, + { + id: userID, + email: "target.user@example.com", + attributes: map[string]any{ + "roles": []string{"admin"}, + "department": "target", + "priority": "high", + "clearance": "level5", + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{"admin"}, + "department": "target", + "priority": "high", + "clearance": "level5", + }, + }, + + "should handle arrays with mixed content correctly": { + id: userID, + testUsers: []testUser{ + { + id: userID, + email: "mixed.array@example.com", + attributes: map[string]any{ + "roles": []string{"admin", "user", "moderator"}, + "numbers": []int{1, 2, 3}, + "mixed_values": []any{"string", 42, true, 3.14}, + "department": "mixed", + }, + }, + }, + setupContext: func() context.Context { return context.Background() }, + expectedResult: map[string]any{ + "roles": []string{"admin", "user", "moderator"}, + "numbers": []any{float64(1), float64(2), float64(3)}, + "mixed_values": []any{"string", float64(42), true, 3.14}, + "department": "mixed", + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + repo := NewUserRepository(pool) + + // Setup test data + setupTestUsersData(t, pool, tc.testUsers) + + // Setup context + ctx := tc.setupContext() + + // Execute + result, err := repo.GetUserAttributesByID(ctx, tc.id) + + // Assert + if tc.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + + // Test custom NotFoundError type + if tc.expectNotFoundErr { + var notFoundErr *repository.NotFoundError + assert.True(t, errors.As(err, ¬FoundErr)) + assert.Equal(t, UserResource, notFoundErr.Resource) + assert.Equal(t, "id", notFoundErr.Key) + assert.Equal(t, tc.id.String(), notFoundErr.Value) + } + } else { + assert.NoError(t, err) + assert.EqualValues(t, tc.expectedResult, result) + } + + // Clean up test data + cleanupTestUsersData(t, pool) + }) + } +} + +func setupTestDBForUsers(t *testing.T) *pgxpool.Pool { + pg := fmt.Sprintf( + "postgres://%s:%s@%s/%s?sslmode=%s", + os.Getenv("POSTGRES_USER"), + os.Getenv("POSTGRES_PASSWORD"), + os.Getenv("POSTGRES_HOST"), + os.Getenv("POSTGRES_DB_TEST"), + os.Getenv("POSTGRES_SSL"), + ) + + pool, err := pgxpool.New(context.Background(), pg) + require.NoError(t, err) + return pool +} + +func setupTestUsersData(t *testing.T, pool *pgxpool.Pool, users []testUser) { + for _, user := range users { + if user.attributes == nil { + _, err := pool.Exec( + context.Background(), + "INSERT INTO users (id, email) VALUES ($1, $2)", + user.id, user.email, + ) + require.NoError(t, err) + continue + } + + _, err := pool.Exec( + context.Background(), + "INSERT INTO users (id, email, attributes) VALUES ($1, $2, $3)", + user.id, user.email, user.attributes, + ) + require.NoError(t, err) + } +} + +func cleanupTestUsersData(t *testing.T, pool *pgxpool.Pool) { + _, err := pool.Exec(context.Background(), "TRUNCATE TABLE users") + require.NoError(t, err) +} diff --git a/examples/abac/internal/requestorchestrator/infoanalyser/rbac_analyser.go b/examples/abac/internal/requestorchestrator/infoanalyser/rbac_analyser.go new file mode 100644 index 0000000..3ff332b --- /dev/null +++ b/examples/abac/internal/requestorchestrator/infoanalyser/rbac_analyser.go @@ -0,0 +1,40 @@ +package infoanalyser + +import ( + "context" + "fmt" + + "github.com/CameronXie/access-control-explorer/abac/infoprovider" + ip "github.com/CameronXie/access-control-explorer/examples/abac/internal/infoprovider" + "github.com/CameronXie/access-control-explorer/examples/abac/internal/requestorchestrator" +) + +type rbacAnalyser struct { + infoType ip.InfoType +} + +func NewRBACAnalyser(infoType ip.InfoType) requestorchestrator.InfoAnalyser { + return &rbacAnalyser{ + infoType: infoType, + } +} + +func (a *rbacAnalyser) AnalyseInfoRequirements( + _ context.Context, + req *requestorchestrator.EnrichedAccessRequest, +) ([]infoprovider.GetInfoRequest, error) { + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + if req.Subject.Attributes == nil || req.Subject.Attributes["roles"] == nil { + return nil, nil + } + + return []infoprovider.GetInfoRequest{ + { + InfoType: string(a.infoType), + Params: req.Subject.Attributes["roles"], + }, + }, nil +} diff --git a/examples/abac/internal/requestorchestrator/requestorchestrator.go b/examples/abac/internal/requestorchestrator/requestorchestrator.go new file mode 100644 index 0000000..09a2e9f --- /dev/null +++ b/examples/abac/internal/requestorchestrator/requestorchestrator.go @@ -0,0 +1,248 @@ +package requestorchestrator + +import ( + "context" + "fmt" + "sync" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" + "github.com/CameronXie/access-control-explorer/abac/infoprovider" + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/google/uuid" + "golang.org/x/sync/errgroup" +) + +type Subject struct { + ro.Subject + Attributes map[string]any `json:"attributes,omitempty"` +} + +type Resource struct { + ro.Resource + Attributes map[string]any `json:"attributes,omitempty"` +} + +type EnrichedAccessRequest struct { + Subject Subject + Action ro.Action + Resource Resource +} + +type InfoAnalyser interface { + AnalyseInfoRequirements(ctx context.Context, req *EnrichedAccessRequest) ([]infoprovider.GetInfoRequest, error) +} + +type requestOrchestrator struct { + infoAnalysers []InfoAnalyser + infoProvider infoprovider.InfoProvider + decisionMaker decisionmaker.DecisionMaker +} + +func NewRequestOrchestrator( + infoAnalysers []InfoAnalyser, + infoProvider infoprovider.InfoProvider, + decisionMaker decisionmaker.DecisionMaker, +) ro.RequestOrchestrator { + return &requestOrchestrator{ + infoAnalysers: infoAnalysers, + infoProvider: infoProvider, + decisionMaker: decisionMaker, + } +} + +// EvaluateAccess processes an access request through enrichment, analysis, and decision-making +func (o *requestOrchestrator) EvaluateAccess(ctx context.Context, req *ro.AccessRequest) (*ro.AccessResponse, error) { + enrichedReq, err := o.enrichAccessRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to enrich request: %w", err) + } + + infoReqs, err := o.AnalyseInfoRequirements(ctx, enrichedReq) + if err != nil { + return nil, fmt.Errorf("failed to analyze requirements: %w", err) + } + + additionalInfo, err := o.getAdditionalInfo(ctx, infoReqs) + if err != nil { + return nil, fmt.Errorf("failed to get additional info: %w", err) + } + + resp, err := o.decisionMaker.MakeDecision(ctx, createDecisionRequest(enrichedReq, additionalInfo)) + if err != nil { + return nil, fmt.Errorf("failed to make decision: %w", err) + } + + return toAccessResponse(resp), nil +} + +// enrichAccessRequest fetches basic subject and resource attributes in parallel +func (o *requestOrchestrator) enrichAccessRequest(ctx context.Context, req *ro.AccessRequest) (*EnrichedAccessRequest, error) { + enrichedReq := &EnrichedAccessRequest{ + Subject: Subject{ + Subject: req.Subject, + Attributes: make(map[string]any), + }, + Action: req.Action, + Resource: Resource{ + Resource: req.Resource, + Attributes: make(map[string]any), + }, + } + + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + + // Fetch subject attributes + g.Go(func() error { + resp, err := o.infoProvider.GetInfo(ctx, &infoprovider.GetInfoRequest{ + InfoType: req.Subject.Type, + Params: req.Subject.ID, + }) + + if err != nil { + return fmt.Errorf("failed to get subject info: %w", err) + } + + mu.Lock() + enrichedReq.Subject.Attributes = resp.Info + mu.Unlock() + return nil + }) + + // Fetch resource attributes + g.Go(func() error { + resp, err := o.infoProvider.GetInfo(ctx, &infoprovider.GetInfoRequest{ + InfoType: req.Resource.Type, + Params: req.Resource.ID, + }) + + if err != nil { + return fmt.Errorf("failed to get resource info: %w", err) + } + + mu.Lock() + enrichedReq.Resource.Attributes = resp.Info + mu.Unlock() + return nil + }) + + return enrichedReq, g.Wait() +} + +// AnalyseInfoRequirements collects additional info requirements from all analyzers +func (o *requestOrchestrator) AnalyseInfoRequirements( + ctx context.Context, + req *EnrichedAccessRequest, +) ([]infoprovider.GetInfoRequest, error) { + var results []infoprovider.GetInfoRequest + + for _, analyser := range o.infoAnalysers { + reqs, err := analyser.AnalyseInfoRequirements(ctx, req) + if err != nil { + return nil, fmt.Errorf("analyser failed: %w", err) + } + results = append(results, reqs...) + } + + return results, nil +} + +// getAdditionalInfo fetches additional info in parallel and returns a consolidated result +func (o *requestOrchestrator) getAdditionalInfo(ctx context.Context, infoReqs []infoprovider.GetInfoRequest) (map[string]any, error) { + if len(infoReqs) == 0 { + return make(map[string]any), nil + } + + result := make(map[string]any) + var mu sync.Mutex + g, ctx := errgroup.WithContext(ctx) + + for idx := range infoReqs { + req := infoReqs[idx] + g.Go(func() error { + resp, err := o.infoProvider.GetInfo(ctx, &req) + if err != nil { + return fmt.Errorf("failed to get info for %s: %w", req.Params, err) + } + + mu.Lock() + for k, v := range resp.Info { + if _, ok := result[k]; ok { + return fmt.Errorf("duplicate info for %s", k) + } + + result[k] = v + } + mu.Unlock() + return nil + }) + } + + return result, g.Wait() +} + +// createDecisionRequest converts enriched request to decision request format +func createDecisionRequest(req *EnrichedAccessRequest, additionalInfo map[string]any) *decisionmaker.DecisionRequest { + return &decisionmaker.DecisionRequest{ + RequestID: uuid.New(), + Subject: decisionmaker.Subject{ + ID: req.Subject.ID, + Type: req.Subject.Type, + Attributes: req.Subject.Attributes, + }, + Action: decisionmaker.Action{ + ID: req.Action.ID, + }, + Resource: decisionmaker.Resource{ + ID: req.Resource.ID, + Type: req.Resource.Type, + Attributes: req.Resource.Attributes, + }, + Environment: additionalInfo, + } +} + +// toAccessResponse converts decision response to access response format +func toAccessResponse(resp *decisionmaker.DecisionResponse) *ro.AccessResponse { + result := &ro.AccessResponse{ + RequestID: resp.RequestID, + Decision: ro.Decision(resp.Decision), + Status: ro.Status{ + Code: ro.StatusCode(resp.Status.Code), + Message: resp.Status.Message, + }, + EvaluatedAt: resp.EvaluatedAt, + PolicyIdReferences: make([]ro.PolicyIdReference, 0, len(resp.PolicyIdReferences)), + } + + // Convert obligations if present + if len(resp.Obligations) > 0 { + result.Obligations = make([]ro.Obligation, len(resp.Obligations)) + for i, obligation := range resp.Obligations { + result.Obligations[i] = ro.Obligation{ + ID: obligation.ID, + Attributes: obligation.Attributes, + } + } + } + + // Convert advice if present + if len(resp.Advice) > 0 { + result.Advices = make([]ro.Advice, len(resp.Advice)) + for i, advice := range resp.Advice { + result.Advices[i] = ro.Advice{ + ID: advice.ID, + Attributes: advice.Attributes, + } + } + } + + for _, policyIdReference := range resp.PolicyIdReferences { + result.PolicyIdReferences = append(result.PolicyIdReferences, ro.PolicyIdReference{ + ID: policyIdReference.ID, + Version: policyIdReference.Version, + }) + } + + return result +} diff --git a/examples/abac/internal/requestorchestrator/requestorchestrator_test.go b/examples/abac/internal/requestorchestrator/requestorchestrator_test.go new file mode 100644 index 0000000..070075f --- /dev/null +++ b/examples/abac/internal/requestorchestrator/requestorchestrator_test.go @@ -0,0 +1,273 @@ +package requestorchestrator + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/CameronXie/access-control-explorer/abac/decisionmaker" + "github.com/CameronXie/access-control-explorer/abac/infoprovider" + ro "github.com/CameronXie/access-control-explorer/abac/requestorchestrator" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockInfoProvider struct { + mock.Mock +} + +func (m *mockInfoProvider) GetInfo(ctx context.Context, req *infoprovider.GetInfoRequest) (*infoprovider.GetInfoResponse, error) { + args := m.Called(ctx, req) + return args.Get(0).(*infoprovider.GetInfoResponse), args.Error(1) +} + +type mockDecisionMaker struct { + mock.Mock +} + +func (m *mockDecisionMaker) MakeDecision(ctx context.Context, req *decisionmaker.DecisionRequest) (*decisionmaker.DecisionResponse, error) { + args := m.Called(ctx, req) + return args.Get(0).(*decisionmaker.DecisionResponse), args.Error(1) +} + +type mockInfoAnalyser struct { + mock.Mock +} + +func (m *mockInfoAnalyser) AnalyseInfoRequirements(ctx context.Context, req *EnrichedAccessRequest) ([]infoprovider.GetInfoRequest, error) { + args := m.Called(ctx, req) + return args.Get(0).([]infoprovider.GetInfoRequest), args.Error(1) +} + +func TestRequestOrchestrator_EvaluateAccess(t *testing.T) { //nolint:gocyclo // unit test + testCases := map[string]struct { + subjectInfoResp *infoprovider.GetInfoResponse + subjectInfoErr error + resourceInfoResp *infoprovider.GetInfoResponse + resourceInfoErr error + analyserReqs []infoprovider.GetInfoRequest + analyserErr error + additionalInfoResp map[string]*infoprovider.GetInfoResponse + additionalInfoErr map[string]error + decisionResp *decisionmaker.DecisionResponse + decisionErr error + expectedResult *ro.AccessResponse + expectedError string + }{ + "should permit access when all info available": { + subjectInfoResp: &infoprovider.GetInfoResponse{ + Info: map[string]any{"role": "admin"}, + }, + resourceInfoResp: &infoprovider.GetInfoResponse{ + Info: map[string]any{"owner": "user123"}, + }, + analyserReqs: []infoprovider.GetInfoRequest{}, + decisionResp: &decisionmaker.DecisionResponse{ + RequestID: uuid.New(), + Decision: decisionmaker.Permit, + Status: &decisionmaker.Status{Code: decisionmaker.StatusOK, Message: "OK"}, + PolicyIdReferences: []decisionmaker.PolicyIdReference{ + {ID: "policy1", Version: "v1"}, + {ID: "policy2", Version: "v2"}, + }, + EvaluatedAt: time.Now(), + }, + expectedResult: &ro.AccessResponse{ + Decision: ro.Permit, + Status: ro.Status{Code: ro.StatusOK, Message: "OK"}, + PolicyIdReferences: []ro.PolicyIdReference{ + {ID: "policy1", Version: "v1"}, + {ID: "policy2", Version: "v2"}, + }, + }, + }, + + "should return error when subject info not found": { + subjectInfoErr: errors.New("user not found"), + resourceInfoResp: &infoprovider.GetInfoResponse{ + Info: map[string]any{"owner": "user123"}, + }, + expectedError: "failed to enrich request: failed to get subject info: user not found", + }, + + "should return error when resource info not found": { + subjectInfoResp: &infoprovider.GetInfoResponse{ + Info: map[string]any{"role": "admin"}, + }, + resourceInfoErr: errors.New("document not found"), + expectedError: "failed to enrich request: failed to get resource info: document not found", + }, + + "should return error when analyser fails": { + subjectInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + resourceInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + analyserErr: errors.New("analysis failed"), + expectedError: "failed to analyze requirements: analyser failed: analysis failed", + }, + + "should return error when additional info unavailable": { + subjectInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + resourceInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + analyserReqs: []infoprovider.GetInfoRequest{{InfoType: "metadata", Params: "extra"}}, + additionalInfoErr: map[string]error{ + "extra": errors.New("metadata unavailable"), + }, + expectedError: "failed to get additional info: failed to get info for extra: metadata unavailable", + }, + + "should return error when decision maker fails": { + subjectInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + resourceInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + analyserReqs: []infoprovider.GetInfoRequest{}, + decisionErr: errors.New("decision failed"), + expectedError: "failed to make decision: decision failed", + }, + + "should deny access with obligations when policy requires": { + subjectInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + resourceInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + analyserReqs: []infoprovider.GetInfoRequest{}, + decisionResp: &decisionmaker.DecisionResponse{ + RequestID: uuid.New(), + Decision: decisionmaker.Deny, + Status: &decisionmaker.Status{Code: decisionmaker.StatusOK, Message: "Access denied"}, + Obligations: []decisionmaker.Obligation{ + {ID: "log", Attributes: map[string]any{"action": "denied"}}, + }, + Advice: []decisionmaker.Advice{ + {ID: "contact", Attributes: map[string]any{"admin": "true"}}, + }, + EvaluatedAt: time.Now(), + PolicyIdReferences: []decisionmaker.PolicyIdReference{ + {ID: "policy1", Version: "v1"}, + }, + }, + expectedResult: &ro.AccessResponse{ + Decision: ro.Deny, + Status: ro.Status{Code: ro.StatusOK, Message: "Access denied"}, + Obligations: []ro.Obligation{ + {ID: "log", Attributes: map[string]any{"action": "denied"}}, + }, + Advices: []ro.Advice{ + {ID: "contact", Attributes: map[string]any{"admin": "true"}}, + }, + PolicyIdReferences: []ro.PolicyIdReference{ + {ID: "policy1", Version: "v1"}, + }, + }, + }, + + "should return error when duplicate info keys detected": { + subjectInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + resourceInfoResp: &infoprovider.GetInfoResponse{Info: map[string]any{}}, + analyserReqs: []infoprovider.GetInfoRequest{ + {InfoType: "metadata", Params: "extra1"}, + {InfoType: "metadata", Params: "extra2"}, + }, + additionalInfoResp: map[string]*infoprovider.GetInfoResponse{ + "extra1": {Info: map[string]any{"key": "value1"}}, + "extra2": {Info: map[string]any{"key": "value2"}}, + }, + expectedError: "failed to get additional info: duplicate info for key", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Setup fresh mocks for each test + mockInfoProvider := new(mockInfoProvider) + mockDecisionMaker := new(mockDecisionMaker) + mockAnalyser := new(mockInfoAnalyser) + + testRequest := &ro.AccessRequest{ + Subject: ro.Subject{ID: "user123", Type: "user"}, + Action: ro.Action{ID: "read"}, + Resource: ro.Resource{ID: "doc456", Type: "document"}, + } + + // Setup subject info mock + if tc.subjectInfoResp != nil || tc.subjectInfoErr != nil { + mockInfoProvider.On("GetInfo", mock.Anything, &infoprovider.GetInfoRequest{ + InfoType: testRequest.Subject.Type, + Params: testRequest.Subject.ID, + }).Return(tc.subjectInfoResp, tc.subjectInfoErr) + } + + // Setup resource info mock + if tc.resourceInfoResp != nil || tc.resourceInfoErr != nil { + mockInfoProvider.On("GetInfo", mock.Anything, &infoprovider.GetInfoRequest{ + InfoType: testRequest.Resource.Type, + Params: testRequest.Resource.ID, + }).Return(tc.resourceInfoResp, tc.resourceInfoErr) + } + + // Setup analyser mock + if tc.analyserReqs != nil || tc.analyserErr != nil { + mockAnalyser.On("AnalyseInfoRequirements", mock.Anything, mock.Anything).Return( + tc.analyserReqs, tc.analyserErr, + ) + } + + // Setup additional info mocks + if tc.additionalInfoResp != nil || tc.additionalInfoErr != nil { + for id, resp := range tc.additionalInfoResp { + err := tc.additionalInfoErr[id] + mockInfoProvider.On("GetInfo", mock.Anything, &infoprovider.GetInfoRequest{ + InfoType: "metadata", + Params: id, + }).Return(resp, err) + } + for id, err := range tc.additionalInfoErr { + if tc.additionalInfoResp[id] == nil { + mockInfoProvider.On("GetInfo", mock.Anything, &infoprovider.GetInfoRequest{ + InfoType: "metadata", + Params: id, + }).Return((*infoprovider.GetInfoResponse)(nil), err) + } + } + } + + // Setup decision maker mock + if tc.decisionResp != nil || tc.decisionErr != nil { + mockDecisionMaker.On("MakeDecision", mock.Anything, mock.Anything).Return( + tc.decisionResp, tc.decisionErr, + ) + } + + orchestrator := NewRequestOrchestrator( + []InfoAnalyser{mockAnalyser}, + mockInfoProvider, + mockDecisionMaker, + ) + + result, err := orchestrator.EvaluateAccess(context.Background(), testRequest) + + if tc.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assertAccessResponse(t, tc.expectedResult, result) + } + + mockInfoProvider.AssertExpectations(t) + mockDecisionMaker.AssertExpectations(t) + mockAnalyser.AssertExpectations(t) + }) + } +} + +func assertAccessResponse(t *testing.T, expected, actual *ro.AccessResponse) { + assert.NotNil(t, actual) + assert.NotNil(t, actual.RequestID) + assert.False(t, actual.EvaluatedAt.IsZero()) + + expected.RequestID = actual.RequestID + expected.EvaluatedAt = actual.EvaluatedAt + + assert.EqualValues(t, expected, actual) +} diff --git a/internal/version/version.go b/examples/abac/internal/version/version.go similarity index 100% rename from internal/version/version.go rename to examples/abac/internal/version/version.go diff --git a/internal/keyfetcher/keyfetcher.go b/examples/abac/pkg/keyfetcher/keyfetcher.go similarity index 100% rename from internal/keyfetcher/keyfetcher.go rename to examples/abac/pkg/keyfetcher/keyfetcher.go diff --git a/internal/keyfetcher/keyfetcher_test.go b/examples/abac/pkg/keyfetcher/keyfetcher_test.go similarity index 100% rename from internal/keyfetcher/keyfetcher_test.go rename to examples/abac/pkg/keyfetcher/keyfetcher_test.go diff --git a/examples/abac/pkg/trie/trie.go b/examples/abac/pkg/trie/trie.go new file mode 100644 index 0000000..51845ed --- /dev/null +++ b/examples/abac/pkg/trie/trie.go @@ -0,0 +1,79 @@ +package trie + +import ( + "fmt" +) + +const ( + WildcardSegment = "*" +) + +// Node represents a trie node with generic value type T +type Node[T any] struct { + Children map[string]*Node[T] + Value T + IsEnd bool +} + +// New creates and returns a new instance of Node with an initialized Children map. +func New[T any]() *Node[T] { + return &Node[T]{ + Children: make(map[string]*Node[T]), + } +} + +// Insert adds a value to the trie at the specified paths, creating intermediate nodes if not present. +// Returns an error if paths are empty or the paths already exist in the trie. +func (n *Node[T]) Insert(paths []string, value T) error { + currentNode := n + + // Traverse/create paths in trie + for _, p := range paths { + if currentNode.Children[p] == nil { + currentNode.Children[p] = &Node[T]{ + Children: make(map[string]*Node[T]), + } + } + + currentNode = currentNode.Children[p] + } + + // Check for duplicate paths + if currentNode.IsEnd { + return fmt.Errorf("paths %v already exists", paths) + } + + // Mark as end node and store value + currentNode.Value = value + currentNode.IsEnd = true + return nil +} + +// Search finds a node by paths, supporting wildcard matching +func (n *Node[T]) Search(path []string) (*Node[T], error) { + currentNode := n + + // Traverse paths with wildcard fallback + for _, p := range path { + // Try exact match first + if currentNode.Children[p] != nil { + currentNode = currentNode.Children[p] + continue + } + + // Fall back to wildcard match + if currentNode.Children[WildcardSegment] != nil { + currentNode = currentNode.Children[WildcardSegment] + continue + } + + return nil, fmt.Errorf("no route found for key %s in paths %v", p, path) + } + + // Verify complete paths exists + if !currentNode.IsEnd { + return nil, fmt.Errorf("paths %v not found", path) + } + + return currentNode, nil +} diff --git a/examples/abac/pkg/trie/trie_test.go b/examples/abac/pkg/trie/trie_test.go new file mode 100644 index 0000000..721ed31 --- /dev/null +++ b/examples/abac/pkg/trie/trie_test.go @@ -0,0 +1,213 @@ +package trie + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testTrieEntry[T any] struct { + paths []string + value T +} + +func TestNode_Insert(t *testing.T) { + testCases := map[string]struct { + trieEntries []testTrieEntry[string] + paths []string + value string + expectedError string + }{ + + "should insert multi-segment paths": { + paths: []string{"api", "v1", "users"}, + value: "users-handler", + }, + + "should insert wildcard paths": { + paths: []string{"api", "*", "status"}, + value: "status-handler", + }, + + "should insert multiple different paths": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "v1", "users"}, + value: "users-handler", + }, + }, + paths: []string{"api", "v2"}, + value: "v2-handler", + }, + + "should return error for duplicate paths": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "users"}, + value: "existing-handler", + }, + }, + paths: []string{"api", "users"}, + value: "new-handler", + expectedError: "paths [api users] already exists", + }, + + "should insert root path successfully": { + paths: []string{}, + value: "new-handler", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Setup + root := New[string]() + if len(tc.trieEntries) > 0 { + for _, entry := range tc.trieEntries { + require.NoError(t, root.Insert(entry.paths, entry.value)) + } + } + + // Execute + err := root.Insert(tc.paths, tc.value) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + return + } + + assert.NoError(t, err) + + // Verify insertion by searching + if len(tc.paths) > 0 { + node, searchErr := root.Search(tc.paths) + assert.NoError(t, searchErr) + assert.Equal(t, tc.value, node.Value) + assert.True(t, node.IsEnd) + } else { + // For empty paths, check root node + assert.Equal(t, tc.value, root.Value) + assert.True(t, root.IsEnd) + } + }) + } +} + +func TestNode_Search(t *testing.T) { + testCases := map[string]struct { + trieEntries []testTrieEntry[string] + paths []string + expectedValue string + expectedError string + }{ + "should find exact match multi-segment": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "v1", "users"}, + value: "users-handler", + }, + }, + paths: []string{"api", "v1", "users"}, + expectedValue: "users-handler", + }, + + "should find wildcard match": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "*", "status"}, + value: "status-handler", + }, + }, + paths: []string{"api", "v1", "status"}, + expectedValue: "status-handler", + }, + + "should prefer exact match over wildcard": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "*", "status"}, + value: "wildcard-handler", + }, + { + paths: []string{"api", "v1", "status"}, + value: "exact-handler", + }, + }, + paths: []string{"api", "v1", "status"}, + expectedValue: "exact-handler", + }, + + "should find root path successfully": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{}, + value: "root-handler", + }, + }, + paths: []string{}, + expectedValue: "root-handler", + }, + + "should return error for non-existent paths": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "v1"}, + value: "v1-handler", + }, + }, + paths: []string{"api", "v2"}, + expectedError: "no route found for key v2 in paths [api v2]", + }, + + "should return error for incomplete paths": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "v1", "users"}, + value: "users-handler", + }, + }, + paths: []string{"api", "v1"}, + expectedError: "paths [api v1] not found", + }, + + "should handle mixed exact and wildcard segments": { + trieEntries: []testTrieEntry[string]{ + { + paths: []string{"api", "*", "users", "*"}, + value: "mixed-handler", + }, + }, + paths: []string{"api", "v1", "users", "123"}, + expectedValue: "mixed-handler", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Setup + root := New[string]() + if len(tc.trieEntries) > 0 { + for _, entry := range tc.trieEntries { + require.NoError(t, root.Insert(entry.paths, entry.value)) + } + } + + // Execute + node, err := root.Search(tc.paths) + + // Assert + if tc.expectedError != "" { + assert.Contains(t, err.Error(), tc.expectedError) + assert.Nil(t, node) + return + } + + assert.Nil(t, err) + assert.NotNil(t, node) + assert.Equal(t, tc.expectedValue, node.Value) + assert.True(t, node.IsEnd) + }) + } +} diff --git a/go.mod b/go.mod index 372d6d2..5ea8157 100644 --- a/go.mod +++ b/go.mod @@ -1,22 +1,25 @@ module github.com/CameronXie/access-control-explorer -go 1.23.2 +go 1.24.1 require ( - github.com/casbin/casbin/v2 v2.100.0 - github.com/casbin/gorm-adapter/v3 v3.28.0 - github.com/go-sql-driver/mysql v1.7.0 + github.com/casbin/casbin/v2 v2.103.0 + github.com/casbin/gorm-adapter/v3 v3.32.0 + github.com/go-sql-driver/mysql v1.9.0 github.com/golang-jwt/jwt/v5 v5.2.1 - github.com/open-policy-agent/opa v0.69.0 - github.com/stretchr/testify v1.9.0 + github.com/google/uuid v1.6.0 + github.com/mattn/go-sqlite3 v1.14.15 + github.com/open-policy-agent/opa v1.2.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/sync v0.11.0 ) require ( - github.com/OneOfOne/xxhash v1.2.8 // indirect - github.com/agnivade/levenshtein v1.2.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/agnivade/levenshtein v1.2.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect - github.com/casbin/govaluate v1.2.0 // indirect + github.com/casbin/govaluate v1.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -28,44 +31,44 @@ require ( github.com/gobwas/glob v0.2.3 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.3 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-isatty v0.0.17 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_golang v1.20.4 // indirect + github.com/prometheus/client_golang v1.21.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/tchap/go-patricia/v2 v2.3.1 // indirect + github.com/tchap/go-patricia/v2 v2.3.2 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yashtewari/glob-intersection v0.2.0 // indirect - go.opentelemetry.io/otel v1.28.0 // indirect - go.opentelemetry.io/otel/metric v1.28.0 // indirect - go.opentelemetry.io/otel/sdk v1.28.0 // indirect - go.opentelemetry.io/otel/trace v1.28.0 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/text v0.18.0 // indirect - google.golang.org/protobuf v1.34.2 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.34.0 // indirect + go.opentelemetry.io/otel/metric v1.34.0 // indirect + go.opentelemetry.io/otel/sdk v1.34.0 // indirect + go.opentelemetry.io/otel/trace v1.34.0 // indirect + golang.org/x/crypto v0.17.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect + google.golang.org/protobuf v1.36.3 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gorm.io/driver/mysql v1.5.6 // indirect - gorm.io/driver/postgres v1.5.7 // indirect + gorm.io/driver/mysql v1.5.7 // indirect + gorm.io/driver/postgres v1.5.9 // indirect gorm.io/driver/sqlserver v1.5.3 // indirect - gorm.io/gorm v1.25.8 // indirect - gorm.io/plugin/dbresolver v1.3.0 // indirect + gorm.io/gorm v1.25.12 // indirect + gorm.io/plugin/dbresolver v1.5.3 // indirect modernc.org/libc v1.22.2 // indirect modernc.org/mathutil v1.5.0 // indirect modernc.org/memory v1.5.0 // indirect diff --git a/go.sum b/go.sum index a171322..21c9905 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= @@ -16,10 +18,8 @@ github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0/go.mod h github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= -github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= -github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= -github.com/agnivade/levenshtein v1.2.0 h1:U9L4IOT0Y3i0TIlUIDJ7rVUziKi/zPbrJGaFrtYH3SY= -github.com/agnivade/levenshtein v1.2.0/go.mod h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtCyDIjcwga4/DU= +github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= +github.com/agnivade/levenshtein v1.2.1/go.mod h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtCyDIjcwga4/DU= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -28,25 +28,23 @@ github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwN github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA= github.com/bytecodealliance/wasmtime-go/v3 v3.0.2/go.mod h1:RnUjnIXxEJcL6BgCvNyzCCRzZcxCgsZCi+RNlvYor5Q= -github.com/casbin/casbin/v2 v2.100.0 h1:aeugSNjjHfCrgA22nHkVvw2xsscboHv5r0a13ljQKGQ= -github.com/casbin/casbin/v2 v2.100.0/go.mod h1:LO7YPez4dX3LgoTCqSQAleQDo0S0BeZBDxYnPUl95Ng= -github.com/casbin/gorm-adapter/v3 v3.28.0 h1:ORF8prF6SfaipdgT1fud+r1Tp5J0uul8QaKJHqCPY/o= -github.com/casbin/gorm-adapter/v3 v3.28.0/go.mod h1:aftWi0cla0CC1bHQVrSFzBcX/98IFK28AvuPppCQgTs= -github.com/casbin/govaluate v1.2.0 h1:wXCXFmqyY+1RwiKfYo3jMKyrtZmOL3kHwaqDyCPOYak= -github.com/casbin/govaluate v1.2.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A= +github.com/casbin/casbin/v2 v2.103.0 h1:dHElatNXNrr8XcseUov0ZSiWjauwmZZE6YMV3eU1yic= +github.com/casbin/casbin/v2 v2.103.0/go.mod h1:Ee33aqGrmES+GNL17L0h9X28wXuo829wnNUnS0edAco= +github.com/casbin/gorm-adapter/v3 v3.32.0 h1:Au+IOILBIE9clox5BJhI2nA3p9t7Ep1ePlupdGbGfus= +github.com/casbin/gorm-adapter/v3 v3.32.0/go.mod h1:Zre/H8p17mpv5U3EaWgPoxLILLdXO3gHW5aoQQpUDZI= +github.com/casbin/govaluate v1.3.0 h1:VA0eSY0M2lA86dYd5kPPuNZMUD9QkWnOCnavGrw9myc= +github.com/casbin/govaluate v1.3.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/badger/v3 v3.2103.5 h1:ylPa6qzbjYRQMU6jokoj4wzcaweHylt//CH0AKt0akg= -github.com/dgraph-io/badger/v3 v3.2103.5/go.mod h1:4MPiseMeDQ3FNCYwRbbcBOGJLf5jsE0PPFzRiKjtcdw= -github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= -github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= +github.com/dgraph-io/badger/v4 v4.5.1 h1:7DCIXrQjo1LKmM96YD+hLVJ2EEsyyoWxJfpdd56HLps= +github.com/dgraph-io/badger/v4 v4.5.1/go.mod h1:qn3Be0j3TfV4kPbVoK0arXCD1/nr1ftth6sbL5jxdoA= +github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= +github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7cNTs5R6Hk4V2lcmLz2NsG2VnInyNo= github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= @@ -70,13 +68,11 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo= +github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= -github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= @@ -86,21 +82,15 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= -github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/flatbuffers v1.12.1 h1:MVlul7pQNoDzWRLTw5imwYsl+usrS1TXG2H4jg6ImGw= -github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/flatbuffers v24.12.23+incompatible h1:ubBKR94NR4pXUCY/MUsRVzd9umNW7ht7EG9hHfS9FX8= +github.com/google/flatbuffers v24.12.23+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -110,16 +100,18 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= -github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= @@ -128,11 +120,10 @@ github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -143,6 +134,8 @@ github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/microsoft/go-mssqldb v1.6.0 h1:mM3gYdVwEPFrlg/Dvr2DNVEgYFG7L42l+dGc67NNNpc= github.com/microsoft/go-mssqldb v1.6.0/go.mod h1:00mDtPbeQCRGC1HwOOR5K/gr30P1NcEG0vx6Kbv2aJU= github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM= @@ -151,20 +144,20 @@ github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3P github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/open-policy-agent/opa v0.69.0 h1:s2igLw2Z6IvGWGuXSfugWkVultDMsM9pXiDuMp7ckWw= -github.com/open-policy-agent/opa v0.69.0/go.mod h1:+qyXJGkpEJ6kpB1kGo8JSwHtVXbTdsGdQYPWWNYNj+4= +github.com/open-policy-agent/opa v1.2.0 h1:88NDVCM0of1eO6Z4AFeL3utTEtMuwloFmWWU7dRV1z0= +github.com/open-policy-agent/opa v1.2.0/go.mod h1:30euUmOvuBoebRCcJ7DMF42bRBOPznvt0ACUMYDUGVY= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= -github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA= +github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= -github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= -github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= @@ -172,8 +165,8 @@ github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqn github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -189,10 +182,10 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tchap/go-patricia/v2 v2.3.1 h1:6rQp39lgIYZ+MHmdEq4xzuk1t7OdC35z/xm0BGhTkes= -github.com/tchap/go-patricia/v2 v2.3.1/go.mod h1:VZRHKAb53DLaG+nA9EaYYiaEx6YztwDlLElMsnSHD4k= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tchap/go-patricia/v2 v2.3.2 h1:xTHFutuitO2zqKAQ5rCROYgUb7Or/+IC3fts9/Yc7nM= +github.com/tchap/go-patricia/v2 v2.3.2/go.mod h1:VZRHKAb53DLaG+nA9EaYYiaEx6YztwDlLElMsnSHD4k= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= @@ -202,22 +195,24 @@ github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Y github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 h1:4K4tsIXefpVJtvA/8srF4V4y0akAoPHkIslgAkjixJA= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0/go.mod h1:jjdQuTGVsXV4vSs+CJ2qYDeDPf9yIJV23qlIzBm73Vg= -go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= -go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 h1:3Q/xZUyC1BBkualc9ROb4G8qkH90LXEIICcs5zv1OYY= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0/go.mod h1:s75jGIWA9OfCMzF0xr+ZgfrB5FEbbV7UuYo32ahUiFI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 h1:R3X6ZXmNPRR8ul6i3WgFURCHzaXjHdm0karRG/+dj3s= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0/go.mod h1:QWFXnDavXWwMx2EEcZsf3yxgEKAqsxQ+Syjp+seyInw= -go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= -go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= -go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBqWyE= -go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= -go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= -go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= -go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= -go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I= +go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= +go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 h1:tgJ0uaNS4c98WRNUEx5U3aDlrDOI5Rs+1Vifcw4DJ8U= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0/go.mod h1:U7HYyW0zt/a9x5J1Kjs+r1f/d4ZHnYFclhYY2+YbeoE= +go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= +go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= +go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= +go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= +go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= +go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= +go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= +go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -225,12 +220,12 @@ golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58 golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -243,13 +238,13 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -264,8 +259,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -279,50 +274,47 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 h1:wKguEg1hsxI2/L3hUYrpo1RVi48K+uTyzKqprwLXsb8= -google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142/go.mod h1:d6be+8HhtEtucleCbxpPW9PA9XwISACu8nvpPqF0BVo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= -google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw= -google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA= +google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= +google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= +google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= +google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= +google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.3.2/go.mod h1:ChK6AHbHgDCFZyJp0F+BmVGb06PSIoh9uVYKAlRbb2U= -gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= -gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= +gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/driver/sqlserver v1.5.3 h1:rjupPS4PVw+rjJkfvr8jn2lJ8BMhT4UW5FwuJY0P3Z0= gorm.io/driver/sqlserver v1.5.3/go.mod h1:B+CZ0/7oFJ6tAlefsKoyxdgDCXJKSgwS2bMOQZT0I00= -gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= -gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.8 h1:WAGEZ/aEcznN4D03laj8DKnehe1e9gYQAjW8xyPRdeo= -gorm.io/gorm v1.25.8/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/plugin/dbresolver v1.3.0 h1:uFDX3bIuH9Lhj5LY2oyqR/bU6pqWuDgas35NAPF4X3M= -gorm.io/plugin/dbresolver v1.3.0/go.mod h1:Pr7p5+JFlgDaiM6sOrli5olekJD16YRunMyA2S7ZfKk= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= +gorm.io/plugin/dbresolver v1.5.3 h1:wFwINGZZmttuu9h7XpvbDHd8Lf9bb8GNzp/NpAMV2wU= +gorm.io/plugin/dbresolver v1.5.3/go.mod h1:TSrVhaUg2DZAWP3PrHlDlITEJmNOkL0tFTjvTEsQ4XE= modernc.org/libc v1.22.2 h1:4U7v51GyhlWqQmwCHj28Rdq2Yzwk55ovjFrdPjs8Hb0= modernc.org/libc v1.22.2/go.mod h1:uvQavJ1pZ0hIoC/jfqNoMLURIMhKzINIWypNM17puug= modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= diff --git a/internal/api/rest/api.go b/internal/api/rest/api.go deleted file mode 100644 index 2d0f6ce..0000000 --- a/internal/api/rest/api.go +++ /dev/null @@ -1,23 +0,0 @@ -package rest - -import ( - "net/http" - - "github.com/CameronXie/access-control-explorer/internal/api/rest/middlewares" -) - -type RouterConfig struct { - SignInHandler http.Handler - ResourceHandler http.Handler - AuthorisationMiddleware middlewares.Middleware -} - -// NewMuxWithHandlers initializes a new HTTP mux with routes defined by the given RouterConfig. -func NewMuxWithHandlers(cfg *RouterConfig) *http.ServeMux { - router := http.NewServeMux() - - router.Handle("POST /auth/signin", cfg.SignInHandler) - router.Handle("GET /api/resources", cfg.AuthorisationMiddleware.Handle(cfg.ResourceHandler)) - - return router -} diff --git a/internal/api/rest/handlers/hardcoded_resources.go b/internal/api/rest/handlers/hardcoded_resources.go deleted file mode 100644 index 307c593..0000000 --- a/internal/api/rest/handlers/hardcoded_resources.go +++ /dev/null @@ -1,31 +0,0 @@ -package handlers - -import ( - "net/http" - - "github.com/CameronXie/access-control-explorer/internal/api/rest/response" -) - -// Resource represents a resource entity with an ID, name, and active status. -type Resource struct { - ID int `json:"id"` - Name string `json:"name"` - Active bool `json:"active"` -} - -// HardcodedResourcesHandler serves hardcoded resources via HTTP in JSON format. -type HardcodedResourcesHandler struct { - resources []Resource -} - -// ServeHTTP handles HTTP requests by responding with a JSON representation of the hardcoded resources. -func (h *HardcodedResourcesHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) { - response.JSONResponse(w, http.StatusOK, map[string]any{"data": h.resources}) -} - -// NewHardcodedResourcesHandler creates a new HTTP handler that serves a JSON representation of hardcoded resources. -func NewHardcodedResourcesHandler(resources []Resource) http.Handler { - return &HardcodedResourcesHandler{ - resources: resources, - } -} diff --git a/internal/api/rest/handlers/signin.go b/internal/api/rest/handlers/signin.go deleted file mode 100644 index a436292..0000000 --- a/internal/api/rest/handlers/signin.go +++ /dev/null @@ -1,97 +0,0 @@ -package handlers - -import ( - "encoding/json" - "log/slog" - "net/http" - "time" - - "github.com/golang-jwt/jwt/v5" - - "github.com/CameronXie/access-control-explorer/internal/api/rest/response" - "github.com/CameronXie/access-control-explorer/internal/authn" - "github.com/CameronXie/access-control-explorer/internal/keyfetcher" -) - -const ( - tokenExpirationDuration = time.Hour - invalidRequestBodyMessage = "invalid request body" - invalidUsernameOrPasswordMessage = "invalid username or password" - internalServerErrorMessage = "internal server error" -) - -type SignInRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -// SignInHandler processes user sign-in requests, authenticates credentials, and generates JWT tokens. -type SignInHandler struct { - authenticator authn.Authenticator - privateKeyFetcher keyfetcher.PrivateKeyFetcher - logger *slog.Logger -} - -// ServeHTTP handles HTTP requests for user sign-in, authenticates users and generates JWT tokens on successful login. -func (h *SignInHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - req := new(SignInRequest) - decodeErr := json.NewDecoder(r.Body).Decode(req) - if decodeErr != nil { - response.JSONErrorResponse(w, http.StatusBadRequest, invalidRequestBodyMessage) - return - } - - h.logger.With("username", req.Username) - authenticatedUser, authError := h.authenticator.Authenticate(req.Username, req.Password) - if authError != nil { - h.logger.ErrorContext(r.Context(), "failed to authenticate user", "error", authError) - response.JSONErrorResponse(w, http.StatusUnauthorized, invalidUsernameOrPasswordMessage) - return - } - - token, jwtError := h.generateJWT(authenticatedUser.Username) - if jwtError != nil { - h.logger.ErrorContext(r.Context(), "failed to generate JWT", "error", jwtError) - response.JSONErrorResponse(w, http.StatusInternalServerError, internalServerErrorMessage) - return - } - - response.JSONResponse(w, http.StatusOK, map[string]string{"token": token}) -} - -// generateJWT generates a JSON Web Token (JWT) for the given username with RS512 signing method and 1-hour expiration. -func (h *SignInHandler) generateJWT(username string) (string, error) { - token := jwt.NewWithClaims( - jwt.SigningMethodRS512, - jwt.MapClaims{ - "sub": username, - "iat": time.Now().Unix(), - "exp": time.Now().Add(tokenExpirationDuration).Unix(), - }, - ) - - privateKey, jwtError := h.privateKeyFetcher.FetchPrivateKey() - if jwtError != nil { - return "", jwtError - } - - tokenString, signError := token.SignedString(privateKey) - if signError != nil { - return "", signError - } - - return tokenString, nil -} - -// NewSignInHandler creates a new HTTP handler for user sign-in, using the provided authenticator and private key fetcher. -func NewSignInHandler( - authenticator authn.Authenticator, - privateKeyFetcher keyfetcher.PrivateKeyFetcher, - logger *slog.Logger, -) http.Handler { - return &SignInHandler{ - authenticator: authenticator, - privateKeyFetcher: privateKeyFetcher, - logger: logger, - } -} diff --git a/internal/api/rest/handlers/signin_test.go b/internal/api/rest/handlers/signin_test.go deleted file mode 100644 index 400b3fd..0000000 --- a/internal/api/rest/handlers/signin_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package handlers - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "errors" - "fmt" - "log/slog" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/CameronXie/access-control-explorer/internal/authn" -) - -// mockAuthenticator is a mock implementation of the Authenticator interface. -type mockAuthenticator struct { - mock.Mock -} - -func (m *mockAuthenticator) Authenticate(username, password string) (*authn.User, error) { - args := m.Called(username, password) - return args.Get(0).(*authn.User), args.Error(1) -} - -// mockPrivateKeyFetcher is a mock implementation of the PrivateKeyFetcher interface. -type mockPrivateKeyFetcher struct { - mock.Mock -} - -func (m *mockPrivateKeyFetcher) FetchPrivateKey() (*rsa.PrivateKey, error) { - args := m.Called() - return args.Get(0).(*rsa.PrivateKey), args.Error(1) -} - -// Helper function to generate a fake RSA private key. -func generateFakeRSAPrivateKey() (*rsa.PrivateKey, error) { - return rsa.GenerateKey(rand.Reader, 2048) -} - -func TestSignInHandler_ServeHTTP(t *testing.T) { - cases := map[string]struct { - requestBody string - mockAuthResult *authn.User - mockAuthError error - mockKeyError error - expectedStatus int - expectedMessage string - expectedLog map[string]string - }{ - "Should Return 200 and Token on Successful Authentication": { - requestBody: `{"username": "testuser", "password": "password"}`, - mockAuthResult: &authn.User{Username: "testuser"}, - mockAuthError: nil, - mockKeyError: nil, - expectedStatus: http.StatusOK, - }, - "Should Return 400 on Invalid Request Body": { - requestBody: "invalid", - mockAuthResult: nil, - mockAuthError: nil, - mockKeyError: nil, - expectedStatus: http.StatusBadRequest, - expectedMessage: invalidRequestBodyMessage, - }, - "Should Return 401 on Authentication Failure": { - requestBody: `{"username": "testuser", "password": "wrongpassword"}`, - mockAuthResult: nil, - mockAuthError: errors.New("auth failed"), - mockKeyError: nil, - expectedStatus: http.StatusUnauthorized, - expectedMessage: invalidUsernameOrPasswordMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to authenticate user", - "error": "auth failed", - }, - }, - "Should Return 500 on Key Fetch Failure": { - requestBody: `{"username": "testuser", "password": "password"}`, - mockAuthResult: &authn.User{Username: "testuser"}, - mockAuthError: nil, - mockKeyError: errors.New("key fetch failed"), - expectedStatus: http.StatusInternalServerError, - expectedMessage: internalServerErrorMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to generate JWT", - "error": "key fetch failed", - }, - }, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - var buf bytes.Buffer - mockAuth := new(mockAuthenticator) - mockKeyFetcher := new(mockPrivateKeyFetcher) - logHandler := slog.NewJSONHandler(&buf, nil) - handler := NewSignInHandler(mockAuth, mockKeyFetcher, slog.New(logHandler)) - - mockAuth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.mockAuthResult, tc.mockAuthError) - - var privateKey *rsa.PrivateKey - if tc.mockKeyError == nil { - key, err := generateFakeRSAPrivateKey() - if err != nil { - t.Fatalf("failed to generate fake RSA private key: %v", err) - } - privateKey = key - } - - mockKeyFetcher.On("FetchPrivateKey").Return(privateKey, tc.mockKeyError) - - w := httptest.NewRecorder() - handler.ServeHTTP( - w, - httptest.NewRequest( - http.MethodPost, - "/", - bytes.NewBufferString(tc.requestBody), - ), - ) - - assert.Equal(t, tc.expectedStatus, w.Code) - - if tc.expectedLog != nil { - log := buf.String() - for k, v := range tc.expectedLog { - assert.Contains(t, log, fmt.Sprintf("%q:%q", k, v)) - } - } - - body := w.Body.String() - if tc.expectedStatus == http.StatusOK { - assert.Contains(t, body, "token") - return - } - - assert.Contains(t, body, tc.expectedMessage) - }) - } -} diff --git a/internal/api/rest/middlewares/jwt_authorization.go b/internal/api/rest/middlewares/jwt_authorization.go deleted file mode 100644 index b3e5f80..0000000 --- a/internal/api/rest/middlewares/jwt_authorization.go +++ /dev/null @@ -1,115 +0,0 @@ -package middlewares - -import ( - "errors" - "log/slog" - "net/http" - "strings" - - "github.com/golang-jwt/jwt/v5" - - "github.com/CameronXie/access-control-explorer/internal/api/rest/response" - "github.com/CameronXie/access-control-explorer/internal/enforcer" - "github.com/CameronXie/access-control-explorer/internal/keyfetcher" -) - -const ( - authHeaderMissingMessage = "authorization header missing" - invalidAuthHeaderFormatMessage = "invalid authorization header format" - internalServerErrorMessage = "internal server error" - invalidTokenMessage = "invalid token" - forbiddenMessage = "forbidden" -) - -// JWTAuthorizationMiddleware handles JWT token authorization, validating tokens and enforcing access policies. -// enforcer is an interface for enforcing access policies. -// publicKeyFetcher is an interface for fetching the public key used to validate JWT tokens. -type JWTAuthorizationMiddleware struct { - enforcer enforcer.Enforcer - publicKeyFetcher keyfetcher.PublicKeyFetcher - logger *slog.Logger -} - -// Handle processes incoming HTTP requests, applying JWT authorization by validating tokens and enforcing access policies. -func (m *JWTAuthorizationMiddleware) Handle(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - response.JSONErrorResponse(w, http.StatusUnauthorized, authHeaderMissingMessage) - return - } - - token, err := extractToken(authHeader) - if err != nil { - m.logger.ErrorContext(r.Context(), "failed to extract token", "error", err) - response.JSONErrorResponse(w, http.StatusUnauthorized, invalidAuthHeaderFormatMessage) - return - } - - publicKey, err := m.publicKeyFetcher.FetchPublicKey() - if err != nil { - m.logger.ErrorContext(r.Context(), "failed to fetch public key", "error", err) - response.JSONErrorResponse(w, http.StatusInternalServerError, internalServerErrorMessage) - return - } - - claims := new(jwt.MapClaims) - _, err = jwt.ParseWithClaims(token, claims, func(_ *jwt.Token) (any, error) { - return publicKey, nil - }) - - if err != nil { - m.logger.ErrorContext(r.Context(), "failed to parse token", "error", err) - response.JSONErrorResponse(w, http.StatusUnauthorized, invalidTokenMessage) - return - } - - sub, err := claims.GetSubject() - if sub == "" || err != nil { - m.logger.ErrorContext(r.Context(), "failed to get subject from token claims") - response.JSONErrorResponse(w, http.StatusUnauthorized, invalidTokenMessage) - return - } - - ok, err := m.enforcer.Enforce( - r.Context(), - &enforcer.AccessRequest{ - Subject: sub, - Resource: r.URL.Path, - Action: r.Method, - }, - ) - - if err != nil || !ok { - m.logger.ErrorContext(r.Context(), "failed to enforce access policy", "error", err) - response.JSONErrorResponse(w, http.StatusForbidden, forbiddenMessage) - return - } - - next.ServeHTTP(w, r) - }) -} - -// extractToken extracts a Bearer token from the Authorization header. -// Returns the extracted token or an error if the header format is invalid. -func extractToken(authHeader string) (string, error) { - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" || parts[1] == "" { - return "", errors.New("invalid authorization header format") - } - - return parts[1], nil -} - -// NewJWTAuthorizationMiddleware returns a new instance of JWTAuthorizationMiddleware with the given enforcer and public key fetcher. -func NewJWTAuthorizationMiddleware( - e enforcer.Enforcer, - publicKeyFetcher keyfetcher.PublicKeyFetcher, - logger *slog.Logger, -) Middleware { - return &JWTAuthorizationMiddleware{ - enforcer: e, - publicKeyFetcher: publicKeyFetcher, - logger: logger, - } -} diff --git a/internal/api/rest/middlewares/jwt_authorization_test.go b/internal/api/rest/middlewares/jwt_authorization_test.go deleted file mode 100644 index de868e8..0000000 --- a/internal/api/rest/middlewares/jwt_authorization_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package middlewares - -import ( - "bytes" - "context" - "crypto/rand" - "crypto/rsa" - "errors" - "fmt" - "log/slog" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - - "github.com/CameronXie/access-control-explorer/internal/enforcer" - "github.com/CameronXie/access-control-explorer/internal/keyfetcher" - - "github.com/stretchr/testify/assert" -) - -type mockEnforcer struct { - enforcer.Enforcer - enforceReturnVal bool - enforceReturnError error -} - -func (m *mockEnforcer) Enforce(_ context.Context, _ *enforcer.AccessRequest) (bool, error) { - return m.enforceReturnVal, m.enforceReturnError -} - -type mockPublicKeyFetcher struct { - keyfetcher.PublicKeyFetcher - publicKey *rsa.PublicKey - fetchReturnError error -} - -func (m *mockPublicKeyFetcher) FetchPublicKey() (*rsa.PublicKey, error) { - if m.fetchReturnError != nil { - return nil, m.fetchReturnError - } - return m.publicKey, nil -} - -func TestJWTAuthorizationMiddleware_Handle(t *testing.T) { - privateKey, publicKey, err := generateKeyPair() - assert.NoError(t, err) - - validToken, err := generateValidToken(privateKey, jwt.MapClaims{ - "sub": "user123", - "exp": time.Now().Add(time.Hour).Unix(), - }) - assert.NoError(t, err) - - tokenWithoutSubClaim, err := generateValidToken(privateKey, jwt.MapClaims{ - "exp": time.Now().Add(time.Hour).Unix(), - }) - assert.NoError(t, err) - - cases := map[string]struct { - authorizationHeader string - expectedStatusCode int - expectedMessage string - expectedLog map[string]string - enforceReturnVal bool - enforceReturnError error - fetchReturnError error - }{ - "HappyPath": { - authorizationHeader: fmt.Sprintf("Bearer %s", validToken), - expectedStatusCode: http.StatusOK, - enforceReturnVal: true, - }, - "InvalidToken": { - authorizationHeader: "Bearer invalidtoken", - expectedStatusCode: http.StatusUnauthorized, - expectedMessage: invalidTokenMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to parse token", - "error": "token is malformed: token contains an invalid number of segments", - }, - }, - "InvalidTokenFormat": { - authorizationHeader: validToken, - expectedStatusCode: http.StatusUnauthorized, - expectedMessage: invalidAuthHeaderFormatMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to extract token", - "error": "invalid authorization header format", - }, - }, - "AuthorizationHeaderMissing": { - expectedStatusCode: http.StatusUnauthorized, - expectedMessage: authHeaderMissingMessage, - }, - "SubClaimMissing": { - authorizationHeader: fmt.Sprintf("Bearer %s", tokenWithoutSubClaim), - expectedStatusCode: http.StatusUnauthorized, - expectedMessage: invalidTokenMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to get subject from token claims", - }, - }, - "EnforcerError": { - authorizationHeader: fmt.Sprintf("Bearer %s", validToken), - expectedStatusCode: http.StatusForbidden, - enforceReturnError: errors.New("some error"), - expectedMessage: forbiddenMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to enforce access policy", - "error": "some error", - }, - }, - "FetchPublicKeyError": { - authorizationHeader: fmt.Sprintf("Bearer %s", validToken), - expectedStatusCode: http.StatusInternalServerError, - fetchReturnError: errors.New("some error"), - expectedMessage: internalServerErrorMessage, - expectedLog: map[string]string{ - "level": "ERROR", - "msg": "failed to fetch public key", - "error": "some error", - }, - }, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - var buf bytes.Buffer - e := &mockEnforcer{enforceReturnVal: tc.enforceReturnVal, enforceReturnError: tc.enforceReturnError} - k := &mockPublicKeyFetcher{publicKey: publicKey, fetchReturnError: tc.fetchReturnError} - h := slog.NewJSONHandler(&buf, nil) - middleware := NewJWTAuthorizationMiddleware(e, k, slog.New(h)) - - request := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - if tc.authorizationHeader != "" { - request.Header.Set("Authorization", tc.authorizationHeader) - } - w := httptest.NewRecorder() - - nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - }) - middleware.Handle(nextHandler).ServeHTTP(w, request) - - assert.Equal(t, tc.expectedStatusCode, w.Code) - if tc.expectedMessage != "" { - assert.Equal(t, fmt.Sprintf("{\"error\":%q}\n", tc.expectedMessage), w.Body.String()) - } - - if tc.expectedLog != nil { - log := buf.String() - for k, v := range tc.expectedLog { - assert.Contains(t, log, fmt.Sprintf("%q:%q", k, v)) - } - } - }) - } -} - -func TestExtractToken(t *testing.T) { - cases := map[string]struct { - input string - expected string - hasError bool - }{ - "valid token": {input: "Bearer tokenvalue", expected: "tokenvalue", hasError: false}, - "invalid token": {input: "tokenvalue", expected: "", hasError: true}, - "empty header": {input: "", expected: "", hasError: true}, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - result, err := extractToken(tc.input) - assert.Equal(t, tc.expected, result) - if tc.hasError { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - } - }) - } -} - -func generateKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - return privateKey, &privateKey.PublicKey, nil -} - -func generateValidToken(privateKey *rsa.PrivateKey, claims jwt.MapClaims) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - return token.SignedString(privateKey) -} diff --git a/internal/api/rest/middlewares/middleware.go b/internal/api/rest/middlewares/middleware.go deleted file mode 100644 index ff299d9..0000000 --- a/internal/api/rest/middlewares/middleware.go +++ /dev/null @@ -1,7 +0,0 @@ -package middlewares - -import "net/http" - -type Middleware interface { - Handle(next http.Handler) http.Handler -} diff --git a/internal/api/rest/response/json_response.go b/internal/api/rest/response/json_response.go deleted file mode 100644 index 738e573..0000000 --- a/internal/api/rest/response/json_response.go +++ /dev/null @@ -1,18 +0,0 @@ -package response - -import ( - "encoding/json" - "net/http" -) - -// JSONResponse writes the given data as a JSON response with the specified status code. -func JSONResponse(w http.ResponseWriter, statusCode int, data any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - _ = json.NewEncoder(w).Encode(data) -} - -// JSONErrorResponse writes an error message as a JSON response with the specified status code. -func JSONErrorResponse(w http.ResponseWriter, statusCode int, message string) { - JSONResponse(w, statusCode, map[string]string{"error": message}) -} diff --git a/internal/api/rest/response/json_response_test.go b/internal/api/rest/response/json_response_test.go deleted file mode 100644 index fbee32f..0000000 --- a/internal/api/rest/response/json_response_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package response - -import ( - "io" - "net/http" - "net/http/httptest" - "testing" -) - -func TestJSONResponse(t *testing.T) { - cases := map[string]struct { - status int - data any - expected string - }{ - "Struct": {http.StatusOK, struct{ Name string }{Name: "test"}, `{"Name":"test"}`}, - "String": {http.StatusOK, "test", `"test"`}, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - rr := httptest.NewRecorder() - JSONResponse(rr, tc.status, tc.data) - checkResponse(t, rr, tc.status, tc.expected) - }) - } -} - -func TestJSONErrorResponse(t *testing.T) { - cases := map[string]struct { - status int - message string - expected string - }{ - "Valid": {http.StatusOK, "test data", `{"error":"test data"}`}, - "NotFound": {http.StatusNotFound, "not found", `{"error":"not found"}`}, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - rr := httptest.NewRecorder() - JSONErrorResponse(rr, tc.status, tc.message) - checkResponse(t, rr, tc.status, tc.expected) - }) - } -} - -func checkResponse(t *testing.T, rr *httptest.ResponseRecorder, expectedStatus int, expectedBody string) { - result := rr.Result() - defer result.Body.Close() - - body, _ := io.ReadAll(result.Body) - - if result.StatusCode != expectedStatus { - t.Errorf("Expected response code %v. Got %v", expectedStatus, result.StatusCode) - } - if string(body) != expectedBody+"\n" { - t.Errorf("Expected response %s. Got %s", expectedBody, string(body)) - } -} diff --git a/internal/authn/authenticator.go b/internal/authn/authenticator.go deleted file mode 100644 index db01596..0000000 --- a/internal/authn/authenticator.go +++ /dev/null @@ -1,9 +0,0 @@ -package authn - -type User struct { - Username string `json:"username"` -} - -type Authenticator interface { - Authenticate(username, password string) (*User, error) -} diff --git a/internal/authn/hardcoded_authenticator.go b/internal/authn/hardcoded_authenticator.go deleted file mode 100644 index 39ddac3..0000000 --- a/internal/authn/hardcoded_authenticator.go +++ /dev/null @@ -1,28 +0,0 @@ -package authn - -import "errors" - -type HardcodedAuthenticator struct { - users map[string]string -} - -func (a *HardcodedAuthenticator) Authenticate(username, password string) (*User, error) { - if a.users == nil { - return nil, errors.New("authentication failed") - } - - pass, ok := a.users[username] - - // This function is for demonstration purposes only and should not be used in production. - // For production, please implement a secure authentication mechanism, such as - // verifying credentials against a database and using proper hashing and salting techniques. - if !ok || pass != password { - return nil, errors.New("password mismatch") - } - - return &User{Username: username}, nil -} - -func NewHardcodedAuthenticator(users map[string]string) Authenticator { - return &HardcodedAuthenticator{users: users} -} diff --git a/internal/decisionmaker/casbin/decisionmaker.go b/internal/decisionmaker/casbin/decisionmaker.go deleted file mode 100644 index 6ff9b32..0000000 --- a/internal/decisionmaker/casbin/decisionmaker.go +++ /dev/null @@ -1,43 +0,0 @@ -package casbin - -import ( - "context" - - "github.com/casbin/casbin/v2" - "github.com/casbin/casbin/v2/model" - "github.com/casbin/casbin/v2/persist" - - "github.com/CameronXie/access-control-explorer/internal/decisionmaker" -) - -type decisionMaker struct { - enforcer casbin.IEnforcer -} - -// MakeDecision evaluates a decision request based on provided subject, resource, and action using the enforcer. -// It first loads the latest policy and then enforces the decision based on the request parameters. -// Returns a boolean indicating the enforcement result and an error if any occurs during policy loading or decision enforcement. -func (d *decisionMaker) MakeDecision(_ context.Context, req *decisionmaker.DecisionRequest) (bool, error) { - err := d.enforcer.LoadPolicy() - if err != nil { - return false, err - } - - return d.enforcer.Enforce(req.Subject, req.Resource, req.Action) -} - -// NewDecisionMaker creates a new instance of DecisionMaker using the provided Casbin configuration and policy repository adapter. -// It returns a DecisionMaker for processing decision requests, or an error if model creation or enforcer initialization fails. -func NewDecisionMaker(config string, policyRepo persist.Adapter) (decisionmaker.DecisionMaker, error) { - m, err := model.NewModelFromString(config) - if err != nil { - return nil, err - } - - enforcer, err := casbin.NewEnforcer(m, policyRepo) - if err != nil { - return nil, err - } - - return &decisionMaker{enforcer: enforcer}, nil -} diff --git a/internal/decisionmaker/casbin/decisionmaker_test.go b/internal/decisionmaker/casbin/decisionmaker_test.go deleted file mode 100644 index 8b82e59..0000000 --- a/internal/decisionmaker/casbin/decisionmaker_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package casbin - -import ( - "context" - "testing" - - "github.com/casbin/casbin/v2" - fileadapter "github.com/casbin/casbin/v2/persist/file-adapter" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/CameronXie/access-control-explorer/internal/decisionmaker" -) - -const ( - policyPath = "testdata/policy.csv" -) - -// mockEnforcer is a mock implementation of the casbin.IEnforcer interface used for testing purposes. -type mockEnforcer struct { - casbin.IEnforcer - mock.Mock -} - -func (e *mockEnforcer) LoadPolicy() error { - args := e.Called() - return args.Error(0) -} - -func (e *mockEnforcer) Enforce(rvals ...any) (bool, error) { - args := e.Called(rvals...) - return args.Bool(0), args.Error(1) -} - -func TestDecisionMaker_MakeDecision(t *testing.T) { - request := &decisionmaker.DecisionRequest{ - Resource: "resource", - Action: "action", - Subject: "subject", - } - - enforcer := new(mockEnforcer) - enforcer.On("LoadPolicy").Return(nil) - enforcer.On( - "Enforce", - request.Subject, request.Resource, request.Action, - ).Return(true, nil) - - decisionMaker := decisionMaker{enforcer: enforcer} - decision, err := decisionMaker.MakeDecision(context.TODO(), request) - - assert.True(t, decision) - assert.NoError(t, err) - enforcer.AssertCalled(t, "Enforce", request.Subject, request.Resource, request.Action) - enforcer.AssertNumberOfCalls(t, "LoadPolicy", 1) - enforcer.AssertNumberOfCalls(t, "Enforce", 1) -} - -func TestNewDecisionMaker(t *testing.T) { - d, err := NewDecisionMaker(getConfig(), fileadapter.NewAdapter(policyPath)) - assert.NoError(t, err) - assert.NotNil(t, d) - - cases := map[string]struct { - request *decisionmaker.DecisionRequest - expectDecision bool - expectError error - }{ - "allow": { - request: &decisionmaker.DecisionRequest{ - Subject: "alice", - Action: "write", - Resource: "data1", - }, - expectDecision: true, - }, - "deny": { - request: &decisionmaker.DecisionRequest{ - Subject: "bob", - Action: "write", - Resource: "data1", - }, - expectDecision: false, - }, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - decision, err := d.MakeDecision(context.TODO(), tc.request) - assert.Equal(t, tc.expectDecision, decision) - assert.Equal(t, tc.expectError, err) - }) - } -} - -func getConfig() string { - return ` -[request_definition] -r = sub, obj, act - -[policy_definition] -p = sub, obj, act - -[role_definition] -g = _, _ - -[policy_effect] -e = some(where (p.eft == allow)) - -[matchers] -m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act -` -} diff --git a/internal/decisionmaker/casbin/testdata/policy.csv b/internal/decisionmaker/casbin/testdata/policy.csv deleted file mode 100644 index d98c3eb..0000000 --- a/internal/decisionmaker/casbin/testdata/policy.csv +++ /dev/null @@ -1,6 +0,0 @@ -p, admin, data1, read -p, admin, data1, write -p, user, data1, read - -g, alice, admin -g, bob, user diff --git a/internal/decisionmaker/decisionmaker.go b/internal/decisionmaker/decisionmaker.go deleted file mode 100644 index 12d0398..0000000 --- a/internal/decisionmaker/decisionmaker.go +++ /dev/null @@ -1,13 +0,0 @@ -package decisionmaker - -import "context" - -type DecisionRequest struct { - Subject string - Resource string - Action string -} - -type DecisionMaker interface { - MakeDecision(ctx context.Context, req *DecisionRequest) (bool, error) -} diff --git a/internal/decisionmaker/opa/decisionmaker.go b/internal/decisionmaker/opa/decisionmaker.go deleted file mode 100644 index e204858..0000000 --- a/internal/decisionmaker/opa/decisionmaker.go +++ /dev/null @@ -1,65 +0,0 @@ -package opa - -import ( - "context" - "fmt" - - "github.com/open-policy-agent/opa/rego" - - "github.com/CameronXie/access-control-explorer/internal/decisionmaker" - "github.com/CameronXie/access-control-explorer/internal/infoprovider" - "github.com/CameronXie/access-control-explorer/internal/policyretriever" -) - -const ( - moduleName = "decisionmaker" -) - -type decisionMaker struct { - policyRetriever policyretriever.PolicyRetriever - infoProvider infoprovider.InfoProvider - query string -} - -// MakeDecision evaluates a policy against the given decision request and returns whether the action is allowed or not. -func (d *decisionMaker) MakeDecision(ctx context.Context, req *decisionmaker.DecisionRequest) (bool, error) { - policy, err := d.policyRetriever.GetPolicy() - if err != nil { - return false, fmt.Errorf("failed to get policy: %w", err) - } - - query, err := rego.New(rego.Module(moduleName, policy), rego.Query(d.query)).PrepareForEval(ctx) - if err != nil { - return false, fmt.Errorf("failed to prepare query: %w", err) - } - - roles, err := d.infoProvider.GetRoles(req.Subject) - if err != nil { - return false, fmt.Errorf("failed to get roles: %w", err) - } - - result, err := query.Eval(ctx, rego.EvalInput(map[string]any{ - "roles": roles, - "action": req.Action, - "resource": req.Resource, - })) - - if len(result) == 0 || result[0].Expressions[0] == nil || err != nil { - return false, fmt.Errorf("failed to evaluate query: %w", err) - } - - return result[0].Expressions[0].Value.(bool), nil -} - -// NewDecisionMaker initializes a DecisionMaker with the provided PolicyRetriever, InfoProvider, and Rego query. -func NewDecisionMaker( - policyRetriever policyretriever.PolicyRetriever, - infoProvider infoprovider.InfoProvider, - query string, -) decisionmaker.DecisionMaker { - return &decisionMaker{ - policyRetriever: policyRetriever, - infoProvider: infoProvider, - query: query, - } -} diff --git a/internal/decisionmaker/opa/decisionmaker_test.go b/internal/decisionmaker/opa/decisionmaker_test.go deleted file mode 100644 index 49b3203..0000000 --- a/internal/decisionmaker/opa/decisionmaker_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package opa - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/CameronXie/access-control-explorer/internal/decisionmaker" -) - -type MockPolicyRetriever struct { - mock.Mock -} - -func (m *MockPolicyRetriever) GetPolicy() (string, error) { - args := m.Called() - return args.String(0), args.Error(1) -} - -type MockInfoProvider struct { - mock.Mock -} - -func (m *MockInfoProvider) GetRoles(id string) ([]string, error) { - args := m.Called(id) - return args.Get(0).([]string), args.Error(1) -} - -func TestMakeDecision(t *testing.T) { - policy := getPolicy() - roles := []string{"admin"} - - cases := map[string]struct { - mockPolicy string - errPolicy error - mockRoles []string - errRoles error - expected bool - wantErr string - }{ - "Success": { - mockPolicy: policy, - errPolicy: nil, - mockRoles: roles, - errRoles: nil, - expected: true, - }, - "Policy retriever error": { - mockPolicy: "", - errPolicy: errors.New("some error"), - mockRoles: roles, - errRoles: nil, - expected: false, - wantErr: "failed to get policy: some error", - }, - "Query initialisation error": { - mockPolicy: "", - errPolicy: nil, - mockRoles: roles, - errRoles: nil, - expected: false, - wantErr: "failed to prepare query: 1 error occurred: decisionmaker:0: rego_parse_error: empty module", - }, - "Information provider error": { - mockPolicy: policy, - errPolicy: nil, - mockRoles: nil, - errRoles: errors.New("some error"), - expected: false, - wantErr: "failed to get roles: some error", - }, - } - - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - policyRetrieverMock := new(MockPolicyRetriever) - infoProviderMock := new(MockInfoProvider) - decisionMaker := NewDecisionMaker(policyRetrieverMock, infoProviderMock, "data.rbac.allow") - request := &decisionmaker.DecisionRequest{ - Subject: "testUser", - Action: "read", - Resource: "database123", - } - - policyRetrieverMock.On("GetPolicy").Return(tc.mockPolicy, tc.errPolicy) - infoProviderMock.On("GetRoles", request.Subject).Return(tc.mockRoles, tc.errRoles) - - got, err := decisionMaker.MakeDecision(context.TODO(), request) - - if tc.wantErr != "" { - assert.EqualError(t, err, tc.wantErr) - } else { - assert.NoError(t, err) - } - - assert.Equal(t, tc.expected, got) - }) - } -} - -func getPolicy() string { - return ` -package rbac - -role_permissions := { - "admin": [{"action": "read", "resource": "database123"}], -} - -default allow = false - -allow { - # for each role in that list - r := input.roles[_] - # lookup the permissions list for role r - permissions := role_permissions[r] - # for each permission - p := permissions[_] - # check if the permission granted to r matches the user's request - p == {"action": input.action, "resource": input.resource} -} -` -} diff --git a/internal/enforcer/enforcer.go b/internal/enforcer/enforcer.go deleted file mode 100644 index 0d16572..0000000 --- a/internal/enforcer/enforcer.go +++ /dev/null @@ -1,37 +0,0 @@ -package enforcer - -import ( - "context" - "strings" - - "github.com/CameronXie/access-control-explorer/internal/decisionmaker" -) - -type Enforcer interface { - Enforce(ctx context.Context, req *AccessRequest) (bool, error) -} - -type AccessRequest struct { - Subject string - Resource string - Action string -} - -type enforcer struct { - decisionMaker decisionmaker.DecisionMaker -} - -func (e *enforcer) Enforce(ctx context.Context, req *AccessRequest) (bool, error) { - return e.decisionMaker.MakeDecision( - ctx, - &decisionmaker.DecisionRequest{ - Subject: strings.ToLower(req.Subject), - Resource: strings.ToLower(req.Resource), - Action: strings.ToLower(req.Action), - }, - ) -} - -func NewEnforcer(decisionMaker decisionmaker.DecisionMaker) Enforcer { - return &enforcer{decisionMaker: decisionMaker} -} diff --git a/internal/infoprovider/infoprovider.go b/internal/infoprovider/infoprovider.go deleted file mode 100644 index 9abec0c..0000000 --- a/internal/infoprovider/infoprovider.go +++ /dev/null @@ -1,5 +0,0 @@ -package infoprovider - -type InfoProvider interface { - GetRoles(id string) ([]string, error) -} diff --git a/internal/infoprovider/opa/hardcoded_infoprovider.go b/internal/infoprovider/opa/hardcoded_infoprovider.go deleted file mode 100644 index 5188279..0000000 --- a/internal/infoprovider/opa/hardcoded_infoprovider.go +++ /dev/null @@ -1,26 +0,0 @@ -package opa - -import ( - "fmt" - - "github.com/CameronXie/access-control-explorer/internal/infoprovider" -) - -type hardcodedInfoProvider struct { - users map[string][]string -} - -// GetRoles returns a slice of roles for a given user ID. -// It returns an error if the user ID is not found. -func (p *hardcodedInfoProvider) GetRoles(id string) ([]string, error) { - if roles, ok := p.users[id]; ok { - return roles, nil - } - - return nil, fmt.Errorf("user %s not found", id) -} - -// NewHardcodedInfoProvider initializes a new InfoProvider with a map of users and their corresponding roles. -func NewHardcodedInfoProvider(users map[string][]string) infoprovider.InfoProvider { - return &hardcodedInfoProvider{users: users} -} diff --git a/internal/policyretriever/opa/hardcoded_policyretriever.go b/internal/policyretriever/opa/hardcoded_policyretriever.go deleted file mode 100644 index 83927a7..0000000 --- a/internal/policyretriever/opa/hardcoded_policyretriever.go +++ /dev/null @@ -1,19 +0,0 @@ -package opa - -import "github.com/CameronXie/access-control-explorer/internal/policyretriever" - -type hardcodedPolicyRetriever struct { - policy string -} - -// GetPolicy retrieves the hardcoded policy as a string and returns it along with any potential error. -func (p *hardcodedPolicyRetriever) GetPolicy() (string, error) { - return p.policy, nil -} - -// NewHardcodedPolicyRetriever creates a PolicyRetriever with a provided hardcoded policy string. -func NewHardcodedPolicyRetriever(policy string) policyretriever.PolicyRetriever { - return &hardcodedPolicyRetriever{ - policy: policy, - } -} diff --git a/internal/policyretriever/policyretriever.go b/internal/policyretriever/policyretriever.go deleted file mode 100644 index fdd1b2e..0000000 --- a/internal/policyretriever/policyretriever.go +++ /dev/null @@ -1,5 +0,0 @@ -package policyretriever - -type PolicyRetriever interface { - GetPolicy() (string, error) -}