Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def configure_app(
default_public=settings.default_public,
public_endpoints=settings.public_endpoints,
private_endpoints=settings.private_endpoints,
items_filter_path=(
settings.items_filter_path if settings.items_filter else None
),
collections_filter_path=(
settings.collections_filter_path
if settings.collections_filter
else None
),
oidc_discovery_url=str(settings.oidc_discovery_url),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Optional
from urllib.parse import urlparse

from starlette.datastructures import Headers
Expand Down Expand Up @@ -35,6 +35,9 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
)

items_filter_path: Optional[str] = None
collections_filter_path: Optional[str] = None

json_content_type_expr: str = r"application/(geo\+)?json"

def should_transform_response(self, request: Request, scope: Scope) -> bool:
Expand Down Expand Up @@ -94,8 +97,10 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
private_endpoints=self.private_endpoints,
public_endpoints=self.public_endpoints,
default_public=self.default_public,
items_filter_path=self.items_filter_path,
collections_filter_path=self.collections_filter_path,
)
if match.is_private:
if match.uses_auth:
auth_refs = ensure_type(link, "auth:refs", list)
auth_refs.append(self.auth_scheme_name)

Expand Down
2 changes: 1 addition & 1 deletion src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
try:
payload = self.validate_token(
request.headers.get("Authorization"),
auto_error=match.is_private,
auto_error=match.uses_auth,
required_scopes=match.required_scopes,
)

Expand Down
12 changes: 4 additions & 8 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class OpenApiMiddleware(JsonResponseMiddleware):
root_path: str = ""
auth_scheme_name: str = "oidcAuth"
auth_scheme_override: Optional[dict] = None

items_filter_path: Optional[str] = None
collections_filter_path: Optional[str] = None

Expand Down Expand Up @@ -79,15 +80,10 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
self.private_endpoints,
self.public_endpoints,
self.default_public,
items_filter_path=self.items_filter_path,
collections_filter_path=self.collections_filter_path,
)
if match.is_private or self._path_has_filter(path):
if match.uses_auth:
security = ensure_type(config, "security", list)
security.append({self.auth_scheme_name: match.required_scopes})
return data

def _path_has_filter(self, path: str) -> bool:
"""Check if a path matches any configured CQL2 filter path."""
for filter_path in (self.items_filter_path, self.collections_filter_path):
if filter_path and re.match(filter_path, path):
return True
return False
17 changes: 12 additions & 5 deletions src/stac_auth_proxy/utils/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,41 @@ def find_match(
private_endpoints: EndpointMethods,
public_endpoints: EndpointMethods,
default_public: bool,
items_filter_path: Optional[str] = None,
collections_filter_path: Optional[str] = None,
) -> "MatchResult":
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
primary_endpoints = private_endpoints if default_public else public_endpoints
matched, required_scopes = _check_endpoint_match(path, method, primary_endpoints)
if matched:
return MatchResult(
is_private=default_public,
uses_auth=default_public,
required_scopes=required_scopes,
)

# If we have filter paths configured, check those as well (these are always considered to use auth if they match, regardless of default_public)
for filter_path in [items_filter_path, collections_filter_path]:
if filter_path and re.match(filter_path, path):
return MatchResult(uses_auth=True)

# If default_public and no match found in private_endpoints, it's public
if default_public:
return MatchResult(is_private=False)
return MatchResult(uses_auth=False)

# If not default_public, check private_endpoints for required scopes
matched, required_scopes = _check_endpoint_match(path, method, private_endpoints)
if matched:
return MatchResult(is_private=True, required_scopes=required_scopes)
return MatchResult(uses_auth=True, required_scopes=required_scopes)

# Default case: if not default_public and no explicit match, it's private
return MatchResult(is_private=True)
return MatchResult(uses_auth=True)


@dataclass
class MatchResult:
"""Result of a match between a path and method and a set of endpoints."""

is_private: bool
uses_auth: bool
required_scopes: Sequence[str] = field(default_factory=list)


Expand Down
194 changes: 194 additions & 0 deletions tests/test_auth_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,176 @@ def test_transform_json_with_invalid_stac_extensions_types(
assert "auth:schemes" in transformed
assert "test_auth" in transformed["auth:schemes"]


class TestFilterPathAnnotation:
"""Tests for auth:refs annotation when links match items/collections filter paths."""

@pytest.fixture
def middleware_with_items_filter(self, oidc_discovery_url):
"""Middleware with items_filter_path configured."""
return AuthenticationExtensionMiddleware(
app=None,
default_public=True,
private_endpoints=EndpointMethods(),
public_endpoints=EndpointMethods(),
oidc_discovery_url=oidc_discovery_url,
auth_scheme_name="test_auth",
items_filter_path=r"^/collections/[^/]+/items$",
)

@pytest.fixture
def middleware_with_collections_filter(self, oidc_discovery_url):
"""Middleware with collections_filter_path configured."""
return AuthenticationExtensionMiddleware(
app=None,
default_public=True,
private_endpoints=EndpointMethods(),
public_endpoints=EndpointMethods(),
oidc_discovery_url=oidc_discovery_url,
auth_scheme_name="test_auth",
collections_filter_path=r"^/collections$",
)

@pytest.fixture
def middleware_with_both_filters(self, oidc_discovery_url):
"""Middleware with both filter paths configured."""
return AuthenticationExtensionMiddleware(
app=None,
default_public=True,
private_endpoints=EndpointMethods(),
public_endpoints=EndpointMethods(),
oidc_discovery_url=oidc_discovery_url,
auth_scheme_name="test_auth",
items_filter_path=r"^/collections/[^/]+/items$",
collections_filter_path=r"^/collections$",
)

def test_items_link_annotated_when_items_filter_matches(
self, middleware_with_items_filter, request_scope
):
"""Links to items endpoints get auth:refs when items_filter_path matches."""
request = Request(request_scope)
data = {
"stac_version": "1.0.0",
"type": "Collection",
"id": "test",
"description": "Test",
"links": [
{"rel": "self", "href": "/collections/test"},
{"rel": "items", "href": "/collections/test/items"},
],
}

transformed = middleware_with_items_filter.transform_json(data, request)

# The self link should NOT have auth:refs (no private endpoints, default_public=True)
self_link = next(link for link in transformed["links"] if link["rel"] == "self")
assert "auth:refs" not in self_link

# The items link SHOULD have auth:refs because it matches items_filter_path
items_link = next(
link for link in transformed["links"] if link["rel"] == "items"
)
assert "auth:refs" in items_link
assert "test_auth" in items_link["auth:refs"]

def test_collections_link_annotated_when_collections_filter_matches(
self, middleware_with_collections_filter, request_scope
):
"""Links to collections endpoint get auth:refs when collections_filter_path matches."""
request = Request(request_scope)
data = {
"stac_version": "1.0.0",
"id": "test-catalog",
"description": "Test catalog",
"links": [
{"rel": "self", "href": "/"},
{"rel": "data", "href": "/collections"},
],
}

transformed = middleware_with_collections_filter.transform_json(data, request)

self_link = next(link for link in transformed["links"] if link["rel"] == "self")
assert "auth:refs" not in self_link

collections_link = next(
link for link in transformed["links"] if link["rel"] == "data"
)
assert "auth:refs" in collections_link
assert "test_auth" in collections_link["auth:refs"]

def test_both_filter_paths_annotated(
self, middleware_with_both_filters, request_scope
):
"""Links matching either filter path get auth:refs."""
request = Request(request_scope)
data = {
"stac_version": "1.0.0",
"id": "test-catalog",
"description": "Test catalog",
"links": [
{"rel": "self", "href": "/"},
{"rel": "data", "href": "/collections"},
{"rel": "items", "href": "/collections/test/items"},
],
}

transformed = middleware_with_both_filters.transform_json(data, request)

self_link = next(link for link in transformed["links"] if link["rel"] == "self")
assert "auth:refs" not in self_link

collections_link = next(
link for link in transformed["links"] if link["rel"] == "data"
)
assert "auth:refs" in collections_link

items_link = next(
link for link in transformed["links"] if link["rel"] == "items"
)
assert "auth:refs" in items_link

def test_non_matching_link_not_annotated(
self, middleware_with_items_filter, request_scope
):
"""Links that don't match filter paths are not annotated when default_public=True."""
request = Request(request_scope)
data = {
"stac_version": "1.0.0",
"type": "Collection",
"id": "test",
"description": "Test",
"links": [
{"rel": "self", "href": "/collections/test"},
{"rel": "root", "href": "/"},
],
}

transformed = middleware_with_items_filter.transform_json(data, request)

for link in transformed["links"]:
assert "auth:refs" not in link

def test_no_filter_paths_configured(self, middleware, request_scope):
"""Without filter paths, default_public=True means no links get auth:refs."""
request = Request(request_scope)
data = {
"stac_version": "1.0.0",
"id": "test-catalog",
"description": "Test catalog",
"links": [
{"rel": "self", "href": "/"},
{"rel": "data", "href": "/collections"},
{"rel": "items", "href": "/collections/test/items"},
],
}

transformed = middleware.transform_json(data, request)

for link in transformed["links"]:
assert "auth:refs" not in link

def test_link_method_used_for_matching(self, oidc_discovery_url, request_scope):
"""Link's method property is used when matching against private endpoints."""
middleware = AuthenticationExtensionMiddleware(
Expand Down Expand Up @@ -284,3 +454,27 @@ def test_link_method_used_for_matching(self, oidc_discovery_url, request_scope):
)
assert "auth:refs" in post_link
assert "test_auth" in post_link["auth:refs"]

def test_filter_path_with_absolute_url(
self, middleware_with_items_filter, request_scope
):
"""Filter path matching works with absolute URLs in link hrefs."""
request = Request(request_scope)
data = {
"stac_version": "1.0.0",
"type": "Collection",
"id": "test",
"description": "Test",
"links": [
{
"rel": "items",
"href": "https://example.com/collections/test/items",
},
],
}

transformed = middleware_with_items_filter.transform_json(data, request)

items_link = transformed["links"][0]
assert "auth:refs" in items_link
assert "test_auth" in items_link["auth:refs"]
Loading