diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 5e02acac..c8ec49bf 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -94,6 +94,14 @@ def configure_app( default_public=settings.default_public, public_endpoints=settings.public_endpoints, private_endpoints=settings.private_endpoints, + items_filter_path=( + settings.items_filter_path if settings.items_filter else None + ), + collections_filter_path=( + settings.collections_filter_path + if settings.collections_filter + else None + ), oidc_discovery_url=str(settings.oidc_discovery_url), ) diff --git a/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py b/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py index dfdb8982..23b63813 100644 --- a/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py +++ b/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py @@ -3,7 +3,7 @@ import logging import re from dataclasses import dataclass, field -from typing import Any +from typing import Any, Optional from urllib.parse import urlparse from starlette.datastructures import Headers @@ -35,6 +35,9 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware): "https://stac-extensions.github.io/authentication/v1.1.0/schema.json" ) + items_filter_path: Optional[str] = None + collections_filter_path: Optional[str] = None + json_content_type_expr: str = r"application/(geo\+)?json" def should_transform_response(self, request: Request, scope: Scope) -> bool: @@ -94,8 +97,10 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An private_endpoints=self.private_endpoints, public_endpoints=self.public_endpoints, default_public=self.default_public, + items_filter_path=self.items_filter_path, + collections_filter_path=self.collections_filter_path, ) - if match.is_private: + if match.uses_auth: auth_refs = ensure_type(link, "auth:refs", list) auth_refs.append(self.auth_scheme_name) diff --git a/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py b/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py index 1ed000c8..5644fb8d 100644 --- a/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py +++ b/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: try: payload = self.validate_token( request.headers.get("Authorization"), - auto_error=match.is_private, + auto_error=match.uses_auth, required_scopes=match.required_scopes, ) diff --git a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py index 35b76eda..8f42955b 100644 --- a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py +++ b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py @@ -27,6 +27,7 @@ class OpenApiMiddleware(JsonResponseMiddleware): root_path: str = "" auth_scheme_name: str = "oidcAuth" auth_scheme_override: Optional[dict] = None + items_filter_path: Optional[str] = None collections_filter_path: Optional[str] = None @@ -79,15 +80,10 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An self.private_endpoints, self.public_endpoints, self.default_public, + items_filter_path=self.items_filter_path, + collections_filter_path=self.collections_filter_path, ) - if match.is_private or self._path_has_filter(path): + if match.uses_auth: security = ensure_type(config, "security", list) security.append({self.auth_scheme_name: match.required_scopes}) return data - - def _path_has_filter(self, path: str) -> bool: - """Check if a path matches any configured CQL2 filter path.""" - for filter_path in (self.items_filter_path, self.collections_filter_path): - if filter_path and re.match(filter_path, path): - return True - return False diff --git a/src/stac_auth_proxy/utils/requests.py b/src/stac_auth_proxy/utils/requests.py index 7b8a53ab..5bbfdb7c 100644 --- a/src/stac_auth_proxy/utils/requests.py +++ b/src/stac_auth_proxy/utils/requests.py @@ -56,34 +56,41 @@ def find_match( private_endpoints: EndpointMethods, public_endpoints: EndpointMethods, default_public: bool, + items_filter_path: Optional[str] = None, + collections_filter_path: Optional[str] = None, ) -> "MatchResult": """Check if the given path and method match any of the regex patterns and methods in the endpoints.""" primary_endpoints = private_endpoints if default_public else public_endpoints matched, required_scopes = _check_endpoint_match(path, method, primary_endpoints) if matched: return MatchResult( - is_private=default_public, + uses_auth=default_public, required_scopes=required_scopes, ) + # If we have filter paths configured, check those as well (these are always considered to use auth if they match, regardless of default_public) + for filter_path in [items_filter_path, collections_filter_path]: + if filter_path and re.match(filter_path, path): + return MatchResult(uses_auth=True) + # If default_public and no match found in private_endpoints, it's public if default_public: - return MatchResult(is_private=False) + return MatchResult(uses_auth=False) # If not default_public, check private_endpoints for required scopes matched, required_scopes = _check_endpoint_match(path, method, private_endpoints) if matched: - return MatchResult(is_private=True, required_scopes=required_scopes) + return MatchResult(uses_auth=True, required_scopes=required_scopes) # Default case: if not default_public and no explicit match, it's private - return MatchResult(is_private=True) + return MatchResult(uses_auth=True) @dataclass class MatchResult: """Result of a match between a path and method and a set of endpoints.""" - is_private: bool + uses_auth: bool required_scopes: Sequence[str] = field(default_factory=list) diff --git a/tests/test_auth_extension.py b/tests/test_auth_extension.py index ba829139..712933ee 100644 --- a/tests/test_auth_extension.py +++ b/tests/test_auth_extension.py @@ -250,6 +250,176 @@ def test_transform_json_with_invalid_stac_extensions_types( assert "auth:schemes" in transformed assert "test_auth" in transformed["auth:schemes"] + +class TestFilterPathAnnotation: + """Tests for auth:refs annotation when links match items/collections filter paths.""" + + @pytest.fixture + def middleware_with_items_filter(self, oidc_discovery_url): + """Middleware with items_filter_path configured.""" + return AuthenticationExtensionMiddleware( + app=None, + default_public=True, + private_endpoints=EndpointMethods(), + public_endpoints=EndpointMethods(), + oidc_discovery_url=oidc_discovery_url, + auth_scheme_name="test_auth", + items_filter_path=r"^/collections/[^/]+/items$", + ) + + @pytest.fixture + def middleware_with_collections_filter(self, oidc_discovery_url): + """Middleware with collections_filter_path configured.""" + return AuthenticationExtensionMiddleware( + app=None, + default_public=True, + private_endpoints=EndpointMethods(), + public_endpoints=EndpointMethods(), + oidc_discovery_url=oidc_discovery_url, + auth_scheme_name="test_auth", + collections_filter_path=r"^/collections$", + ) + + @pytest.fixture + def middleware_with_both_filters(self, oidc_discovery_url): + """Middleware with both filter paths configured.""" + return AuthenticationExtensionMiddleware( + app=None, + default_public=True, + private_endpoints=EndpointMethods(), + public_endpoints=EndpointMethods(), + oidc_discovery_url=oidc_discovery_url, + auth_scheme_name="test_auth", + items_filter_path=r"^/collections/[^/]+/items$", + collections_filter_path=r"^/collections$", + ) + + def test_items_link_annotated_when_items_filter_matches( + self, middleware_with_items_filter, request_scope + ): + """Links to items endpoints get auth:refs when items_filter_path matches.""" + request = Request(request_scope) + data = { + "stac_version": "1.0.0", + "type": "Collection", + "id": "test", + "description": "Test", + "links": [ + {"rel": "self", "href": "/collections/test"}, + {"rel": "items", "href": "/collections/test/items"}, + ], + } + + transformed = middleware_with_items_filter.transform_json(data, request) + + # The self link should NOT have auth:refs (no private endpoints, default_public=True) + self_link = next(link for link in transformed["links"] if link["rel"] == "self") + assert "auth:refs" not in self_link + + # The items link SHOULD have auth:refs because it matches items_filter_path + items_link = next( + link for link in transformed["links"] if link["rel"] == "items" + ) + assert "auth:refs" in items_link + assert "test_auth" in items_link["auth:refs"] + + def test_collections_link_annotated_when_collections_filter_matches( + self, middleware_with_collections_filter, request_scope + ): + """Links to collections endpoint get auth:refs when collections_filter_path matches.""" + request = Request(request_scope) + data = { + "stac_version": "1.0.0", + "id": "test-catalog", + "description": "Test catalog", + "links": [ + {"rel": "self", "href": "/"}, + {"rel": "data", "href": "/collections"}, + ], + } + + transformed = middleware_with_collections_filter.transform_json(data, request) + + self_link = next(link for link in transformed["links"] if link["rel"] == "self") + assert "auth:refs" not in self_link + + collections_link = next( + link for link in transformed["links"] if link["rel"] == "data" + ) + assert "auth:refs" in collections_link + assert "test_auth" in collections_link["auth:refs"] + + def test_both_filter_paths_annotated( + self, middleware_with_both_filters, request_scope + ): + """Links matching either filter path get auth:refs.""" + request = Request(request_scope) + data = { + "stac_version": "1.0.0", + "id": "test-catalog", + "description": "Test catalog", + "links": [ + {"rel": "self", "href": "/"}, + {"rel": "data", "href": "/collections"}, + {"rel": "items", "href": "/collections/test/items"}, + ], + } + + transformed = middleware_with_both_filters.transform_json(data, request) + + self_link = next(link for link in transformed["links"] if link["rel"] == "self") + assert "auth:refs" not in self_link + + collections_link = next( + link for link in transformed["links"] if link["rel"] == "data" + ) + assert "auth:refs" in collections_link + + items_link = next( + link for link in transformed["links"] if link["rel"] == "items" + ) + assert "auth:refs" in items_link + + def test_non_matching_link_not_annotated( + self, middleware_with_items_filter, request_scope + ): + """Links that don't match filter paths are not annotated when default_public=True.""" + request = Request(request_scope) + data = { + "stac_version": "1.0.0", + "type": "Collection", + "id": "test", + "description": "Test", + "links": [ + {"rel": "self", "href": "/collections/test"}, + {"rel": "root", "href": "/"}, + ], + } + + transformed = middleware_with_items_filter.transform_json(data, request) + + for link in transformed["links"]: + assert "auth:refs" not in link + + def test_no_filter_paths_configured(self, middleware, request_scope): + """Without filter paths, default_public=True means no links get auth:refs.""" + request = Request(request_scope) + data = { + "stac_version": "1.0.0", + "id": "test-catalog", + "description": "Test catalog", + "links": [ + {"rel": "self", "href": "/"}, + {"rel": "data", "href": "/collections"}, + {"rel": "items", "href": "/collections/test/items"}, + ], + } + + transformed = middleware.transform_json(data, request) + + for link in transformed["links"]: + assert "auth:refs" not in link + def test_link_method_used_for_matching(self, oidc_discovery_url, request_scope): """Link's method property is used when matching against private endpoints.""" middleware = AuthenticationExtensionMiddleware( @@ -284,3 +454,27 @@ def test_link_method_used_for_matching(self, oidc_discovery_url, request_scope): ) assert "auth:refs" in post_link assert "test_auth" in post_link["auth:refs"] + + def test_filter_path_with_absolute_url( + self, middleware_with_items_filter, request_scope + ): + """Filter path matching works with absolute URLs in link hrefs.""" + request = Request(request_scope) + data = { + "stac_version": "1.0.0", + "type": "Collection", + "id": "test", + "description": "Test", + "links": [ + { + "rel": "items", + "href": "https://example.com/collections/test/items", + }, + ], + } + + transformed = middleware_with_items_filter.transform_json(data, request) + + items_link = transformed["links"][0] + assert "auth:refs" in items_link + assert "test_auth" in items_link["auth:refs"]