diff --git a/.github/workflows/data/noaa-emergency-response.json b/.github/workflows/data/noaa-emergency-response.json index dd9aabd4..5850e5ee 100644 --- a/.github/workflows/data/noaa-emergency-response.json +++ b/.github/workflows/data/noaa-emergency-response.json @@ -3,6 +3,9 @@ "title": "NOAA Emergency Response Imagery", "description": "NOAA Emergency Response Imagery hosted on AWS Public Dataset.", "stac_version": "1.0.0", + "properties": { + "tenant": "test-tenant" + }, "license": "public-domain", "links": [], "extent": { diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index 6a65cf74..708908ce 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -1,4 +1,4 @@ -"""FastAPI application using PGStac. +"""FastAPI application using PGStac with integrated tenant filtering. Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac """ @@ -27,8 +27,10 @@ from starlette.templating import Jinja2Templates from starlette_cramjam.middleware import CompressionMiddleware -from .core import VedaCrudClient from .monitoring import LoggerRouteHandler, logger, metrics, tracer +from .tenant_client import TenantAwareVedaCrudClient +from .tenant_middleware import TenantMiddleware +from .tenant_routes import create_tenant_router from .validation import ValidationMiddleware from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings @@ -54,6 +56,8 @@ async def lifespan(app: FastAPI): await close_db_connection(app) +tenant_client = TenantAwareVedaCrudClient(pgstac_search_model=post_request_model) + api = StacApi( app=FastAPI( title=f"{api_settings.project_name} STAC API", @@ -76,17 +80,28 @@ async def lifespan(app: FastAPI): description=api_settings.project_description, settings=api_settings, extensions=application_extensions, - client=VedaCrudClient(pgstac_search_model=post_request_model), + client=tenant_client, search_get_request_model=get_request_model, search_post_request_model=post_request_model, 
collections_get_request_model=collections_get_request_model, items_get_request_model=items_get_request_model, response_class=ORJSONResponse, - middlewares=[Middleware(CompressionMiddleware), Middleware(ValidationMiddleware)], + middlewares=[ + Middleware(CompressionMiddleware), + Middleware(ValidationMiddleware), + Middleware(TenantMiddleware), + ], router=APIRouter(route_class=LoggerRouteHandler), ) app = api.app +# Add tenant-specific routes +logger.info("Creating tenant router...") +tenant_router = create_tenant_router(tenant_client) +logger.info(f"Registering tenant router with {len(tenant_router.routes)} routes") +app.include_router(tenant_router, tags=["Tenant-specific endpoints"]) +logger.info("Tenant router registered successfully") + # Set all CORS enabled origins if api_settings.cors_origins: app.add_middleware( @@ -143,6 +158,20 @@ async def viewer_page(request: Request): ) +@app.get("/{tenant}/index.html", response_class=HTMLResponse) +async def tenant_viewer_page(request: Request, tenant: str): + """Tenant-specific search viewer.""" + return templates.TemplateResponse( + "stac-viewer.html", + { + "request": request, + "endpoint": str(request.url).replace("/index.html", f"/{tenant}"), + "tenant": tenant, + }, + media_type="text/html", + ) + + # If the correlation header is used in the UI, we can analyze traces that originate from a given user or client @app.middleware("http") async def add_correlation_id(request: Request, call_next): diff --git a/stac_api/runtime/src/tenant_client.py b/stac_api/runtime/src/tenant_client.py new file mode 100644 index 00000000..19a7fff0 --- /dev/null +++ b/stac_api/runtime/src/tenant_client.py @@ -0,0 +1,334 @@ +""" +Tenant Client for Tenant Middleware +""" +import logging +from typing import Any, Dict, Optional, Union +from urllib.parse import urlparse + +from fastapi import HTTPException +from fastapi import Request as FastAPIRequest +from stac_fastapi.types.stac import Collection, Item, ItemCollection, LandingPage 
from starlette.requests import Request

from .core import VedaCrudClient
from .tenant_models import TenantValidationError

logger = logging.getLogger(__name__)

# Path prefix that landing-page link rewriting looks for in link hrefs.
_API_ROOT = "/api/stac"

# Link relations that keep pointing at the un-scoped API when a landing page
# is customized for a tenant.
_TENANT_AGNOSTIC_RELS = frozenset(
    {"self", "root", "service-desc", "service-doc", "conformance"}
)


class TenantValidationMixin:
    """Mixin with helpers that check which tenant a STAC resource belongs to.

    The tenant of a resource is read from ``resource["properties"]["tenant"]``.
    """

    def validate_tenant_access(
        self,
        resource: Union[Dict[str, Any], Collection],
        tenant: str,
        resource_id: str = "",
    ) -> None:
        """Raise unless ``resource`` is tagged with ``tenant``.

        Args:
            resource: STAC resource (dict-like) to check.
            tenant: Tenant identifier the caller is scoped to.
            resource_id: Identifier used in the error message.

        Raises:
            TenantValidationError: if the resource's tenant differs from
                ``tenant`` (including when the resource has no tenant).
        """
        resource_tenant = self._extract_tenant_from_resource(resource)
        if resource_tenant == tenant:
            return

        raise TenantValidationError(
            resource_type=self._resource_type_label(resource),
            resource_id=resource_id,
            tenant=tenant,
            actual_tenant=resource_tenant,
        )

    @staticmethod
    def _resource_type_label(resource: Union[Dict[str, Any], Collection]) -> str:
        """Return ``"Collection"`` or ``"Item"`` for use in error messages.

        STAC Collections declare ``"type": "Collection"`` while Items declare
        ``"type": "Feature"`` and carry a ``collection`` key naming their
        parent, so the presence of ``collection`` indicates an Item (the
        previous check had this inverted).
        """
        stac_type = resource.get("type")
        if stac_type == "Collection":
            return "Collection"
        if stac_type == "Feature" or "collection" in resource:
            return "Item"
        # Untyped resources default to Collection, the only kind this module
        # currently validates.
        return "Collection"

    def _extract_tenant_from_resource(
        self, resource: Union[Dict[str, Any], Collection]
    ) -> Optional[str]:
        """Return ``properties.tenant`` of ``resource`` or ``None`` if absent."""
        # ``or {}`` also covers an explicit ``"properties": null``.
        return (resource.get("properties") or {}).get("tenant")


class TenantAwareVedaCrudClient(VedaCrudClient, TenantValidationMixin):
    """VEDA CRUD client whose read operations accept an optional ``tenant``.

    When ``tenant`` is supplied, collection/item reads are validated against
    the owning collection's tenant, listings and search results are filtered
    by ``properties.tenant``, and landing-page links are rewritten to include
    the tenant path segment. Without ``tenant`` every method defers to the
    parent client unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``VedaCrudClient``.

        Args:
            *args: positional args passed to parent VedaCrudClient
            **kwargs: keyword args passed to parent VedaCrudClient such as
                pgstac_search_model
        """
        super().__init__(*args, **kwargs)

    def get_tenant_from_request(self, request: Request) -> Optional[str]:
        """Return the ``tenant`` path parameter of ``request``, if any.

        Args:
            request: Incoming request

        Returns:
            tenant, if there is one. None otherwise.
        """
        path_params = getattr(request, "path_params", None) or {}
        return path_params.get("tenant")

    async def get_tenant_collections(
        self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs
    ) -> Dict[str, Any]:
        """Return the collections listing, filtered to ``tenant`` when given.

        Args:
            request: Incoming request
            tenant: Tenant ID; falsy means "do not filter"
            **kwargs: forwarded to the parent ``all_collections``

        Returns:
            The parent's collections response; when filtering applies, its
            ``collections`` list and ``numberReturned`` (if present) are
            updated in place.
        """
        collections = await super().all_collections(request, **kwargs)

        if (
            not tenant
            or not isinstance(collections, dict)
            or "collections" not in collections
        ):
            return collections

        # NOTE(review): filtering happens after the parent has paginated, so
        # any ``numberMatched``/paging links still describe the unfiltered
        # result set.
        owned = [
            collection
            for collection in collections["collections"]
            if self._extract_tenant_from_resource(collection) == tenant
        ]
        collections["collections"] = owned
        if "numberReturned" in collections:
            collections["numberReturned"] = len(owned)

        return collections

    async def get_collection(
        self,
        collection_id: str,
        request: FastAPIRequest,
        tenant: Optional[str] = None,
        **kwargs,
    ) -> Collection:
        """Get a collection by ID, checking tenant ownership when given.

        Raises:
            TenantValidationError: if ``tenant`` is given and the collection
                is tagged with a different (or no) tenant.
        """
        collection = await super().get_collection(collection_id, request, **kwargs)

        if tenant and collection:
            self.validate_tenant_access(collection, tenant, collection_id)

        return collection

    async def _require_tenant_collection(
        self, collection_id: str, request: FastAPIRequest, tenant: str, **kwargs
    ) -> None:
        """Ensure ``collection_id`` exists and belongs to ``tenant``.

        Raises:
            HTTPException: 404 if the parent client returns no collection.
            TenantValidationError: if the collection belongs to another tenant.
        """
        collection = await super().get_collection(collection_id, request, **kwargs)
        if not collection:
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found for tenant {tenant}",
            )
        self.validate_tenant_access(collection, tenant, collection_id)

    async def item_collection(
        self,
        collection_id: str,
        request: FastAPIRequest,
        tenant: Optional[str] = None,
        limit: int = 10,
        token: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """Get the items of a collection, checking tenant ownership first.

        When ``tenant`` is given, the parent collection must exist and be
        tagged with that tenant; the items themselves are not filtered.
        """
        if tenant:
            await self._require_tenant_collection(
                collection_id, request, tenant, **kwargs
            )

        return await super().item_collection(
            collection_id=collection_id,
            request=request,
            limit=limit,
            token=token,
            **kwargs,
        )

    async def get_item(
        self,
        item_id: str,
        collection_id: str,
        request: FastAPIRequest,
        tenant: Optional[str] = None,
        **kwargs,
    ) -> Item:
        """Get one item, checking its collection's tenant ownership first."""
        if tenant:
            await self._require_tenant_collection(
                collection_id, request, tenant, **kwargs
            )

        return await super().get_item(item_id, collection_id, request, **kwargs)

    async def post_search(
        self,
        search_request,
        request: FastAPIRequest,
        tenant: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """POST Search request with tenant filtering

        Args:
            search_request: the search request parameters
            request: the FastAPI request object
            tenant: optional tenant identifier for filtering search
            **kwargs: additional arguments to pass to the parent method

        Returns:
            ItemCollection of the filtered search results
        """
        result = await super().post_search(search_request, request, **kwargs)

        if tenant:
            result = self._filter_search_results_by_tenant(result, tenant)

        return result

    async def get_search(
        self,
        request: FastAPIRequest,
        tenant: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """GET Search request with tenant filtering

        Args:
            request: the FastAPI request object
            tenant: optional tenant identifier for filtering search
            **kwargs: search parameters forwarded to the parent method

        Returns:
            ItemCollection of the filtered search results
        """
        result = await super().get_search(request, **kwargs)

        if tenant:
            result = self._filter_search_results_by_tenant(result, tenant)

        return result

    def _filter_search_results_by_tenant(
        self, result: ItemCollection, tenant: str
    ) -> ItemCollection:
        """Keep only features whose ``properties.tenant`` equals ``tenant``.

        Args:
            result: ItemCollection to filter
            tenant: Tenant identifier to filter on

        Returns:
            The same object with ``features`` (and ``numberReturned``, when
            present) updated; returned untouched if it is not a dict with a
            ``features`` key.
        """
        if not isinstance(result, dict) or "features" not in result:
            return result

        # NOTE(review): like collection listings, this trims an already
        # paginated page; ``numberMatched`` and paging links are not adjusted.
        kept = [
            feature
            for feature in result["features"]
            if self._extract_tenant_from_resource(feature) == tenant
        ]
        result["features"] = kept
        if "numberReturned" in result:
            result["numberReturned"] = len(kept)

        return result

    async def landing_page(
        self, request: FastAPIRequest, tenant: Optional[str] = None, **kwargs
    ) -> LandingPage:
        """Get the landing page, customized when a tenant is provided

        Args:
            request: Fast API request object
            tenant: Optional tenant identifier
            **kwargs: Optional key word args to pass to parent method

        Returns:
            Landing Page, customized if tenant provided
        """
        # Set by TenantMiddleware when it processed the request; may be absent.
        tenant_context = getattr(request.state, "tenant_context", None)
        request_id = tenant_context.request_id if tenant_context else None

        logger.info(
            f"Landing page requested for tenant: {tenant}",
            extra={
                "tenant_id": tenant,
                "request_id": request_id,
                "endpoint": "landing_page",
            },
        )

        landing_page = await super().landing_page(request=request, **kwargs)

        if tenant:
            landing_page = self._customize_landing_page_for_tenant(landing_page, tenant)

            logger.info(
                f"Landing page customized for tenant: {tenant}",
                extra={
                    "tenant_id": tenant,
                    "request_id": request_id,
                    "links_modified": len(landing_page.get("links", [])),
                },
            )

        return landing_page

    def _customize_landing_page_for_tenant(
        self, landing_page: LandingPage, tenant: str
    ) -> LandingPage:
        """Prefix the title with the tenant and scope links to the tenant.

        Links whose ``rel`` is tenant-agnostic (self, root, service-desc,
        service-doc, conformance, or anything containing ``queryables``) are
        left alone, as are links without a string ``href``.
        """
        if "title" in landing_page:
            landing_page["title"] = f"{tenant.upper()} - {landing_page['title']}"

        for link in landing_page.get("links", []):
            logger.debug("Inspecting links to inject tenant...")
            href = link.get("href")
            if not isinstance(href, str):
                continue

            # A link without ``rel`` used to raise TypeError on the ``in`` test.
            rel = link.get("rel") or ""
            if rel in _TENANT_AGNOSTIC_RELS or "queryables" in rel:
                continue

            link["href"] = self._inject_tenant_into_href(href, tenant)

        return landing_page

    @staticmethod
    def _inject_tenant_into_href(href: str, tenant: str) -> str:
        """Insert ``tenant`` right after the ``/api/stac`` prefix of ``href``.

        Hrefs that do not start with that prefix are returned unchanged.
        """
        if href.startswith("http"):
            # a URL should follow this structure
            # scheme://netloc/path;parameters?query#fragment generally
            # source: https://docs.python.org/3/library/urllib.parse.html
            parsed = urlparse(href)
            segments = parsed.path.split("/")
            if segments[1:3] != ["api", "stac"]:
                return href
            segments.insert(3, tenant)
            # Replace only the path so params, query and fragment survive.
            return parsed._replace(path="/".join(segments)).geturl()

        if href.startswith(_API_ROOT):
            # count=1: only the leading prefix, not later occurrences.
            return href.replace(_API_ROOT, f"{_API_ROOT}/{tenant}", 1)

        return href


# ---------------------------------------------------------------------------
# stac_api/runtime/src/tenant_middleware.py
# ---------------------------------------------------------------------------
""" Tenant Middleware for STAC API. Useful for extracting tenant information """
import logging
from typing import Optional

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from .tenant_models import TenantContext, TenantValidationError

logger = logging.getLogger(__name__)

# Production deployments serve the API under this prefix; local runs do not.
_API_PREFIX = "/api/stac"

# First path segments that belong to the un-scoped API and are never tenants.
# "index.html" is the root viewer page registered in app.py.
# NOTE(review): stac-fastapi's health check lives under "/_mgmt/..." - confirm
# whether "_mgmt" should be listed here as well.
_STANDARD_ENDPOINTS = frozenset(
    {
        "collections",
        "conformance",
        "search",
        "queryables",
        "openapi.json",
        "docs",
        "favicon.ico",
        "health",
        "ping",
        "index.html",
    }
)


class TenantMiddleware(BaseHTTPMiddleware):
    """Middleware for tenant-aware STAC API request processing.

    This middleware extracts the tenant identifier from the URL path and
    creates a context for downstream processing. It also handles validation
    errors.

    It will process requests by:
        - extracting the tenant from the URL path (/api/stac/{tenant}/... or,
          without the prefix, /{tenant}/...)
        - creating a TenantContext with the tenant ID and correlation ID and
          storing it on ``request.state.tenant_context``
        - echoing the tenant/correlation IDs in response headers
        - turning tenant validation failures into JSON error responses
    """

    def __init__(self, app):
        """Initializes the tenant middleware"""
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Attach tenant context to the request and tag the response.

        Returns:
            The downstream response, or a JSON 400/404 response when the
            tenant identifier is invalid or tenant validation fails.
        """
        try:
            if self._should_skip_tenant_processing(request):
                return await call_next(request)

            tenant = self._extract_tenant(request)
            logger.info(f"Extracted tenant is {tenant}")

            tenant_context = None
            if tenant:
                try:
                    tenant_context = TenantContext(
                        tenant_id=tenant,
                        request_id=request.headers.get("X-Correlation-ID"),
                    )
                except ValueError as exc:
                    # pydantic validation errors are ValueErrors; answer with
                    # a client error instead of letting them become a 500.
                    logger.warning(f"Rejected invalid tenant identifier: {exc}")
                    return JSONResponse(
                        status_code=400,
                        content={"detail": "Invalid tenant identifier"},
                    )

            request.state.tenant_context = tenant_context

            if tenant_context:
                logger.info(
                    f"Tenant access: {tenant} for {request.method} {request.url.path}",
                    extra={
                        "tenant": tenant,
                        "method": request.method,
                        "path": request.url.path,
                    },
                )

            response = await call_next(request)

            if tenant_context:
                response.headers["X-Tenant-ID"] = tenant_context.tenant_id
                if tenant_context.request_id:
                    response.headers["X-Request-ID"] = tenant_context.request_id

            return response

        except TenantValidationError as e:
            logger.warning(
                f"Tenant validation failed: {e.detail}",
                extra={
                    "tenant": getattr(e, "tenant", None),
                    "resource_type": getattr(e, "resource_type", None),
                    "resource_id": getattr(e, "resource_id", None),
                },
            )
            # Return the response directly: this middleware runs outside the
            # application's HTTPException handlers, so re-raising an
            # HTTPException here would reach the client as a 500.
            return JSONResponse(status_code=404, content={"detail": e.detail})

        except Exception as e:
            logger.exception(f"Tenant middleware error: {str(e)}")
            raise

    @staticmethod
    def _relative_path_parts(path: str) -> list:
        """Split ``path`` into segments after removing the API prefix.

        Only a leading ``/api/stac`` is removed, so both prefixed and
        un-prefixed deployments yield the same segments. The result always
        has at least one element (possibly the empty string).
        """
        if path == _API_PREFIX or path.startswith(_API_PREFIX + "/"):
            path = path[len(_API_PREFIX) :]
        return path.lstrip("/").split("/")

    def _should_skip_tenant_processing(self, request: Request) -> bool:
        """Check if tenant processing should be skipped for this request.

        Skips the API root and any path whose first segment is one of the
        standard (un-scoped) endpoints.
        """
        path = request.url.path
        first_part = self._relative_path_parts(path)[0]

        if not first_part:
            logger.debug(f"Skipping tenant processing - empty path parts: {path}")
            return True

        if first_part in _STANDARD_ENDPOINTS:
            logger.debug(
                f"Skipping tenant processing for standard endpoint: {first_part}"
            )
            return True

        logger.debug(f"Processing as tenant: {first_part}")
        return False

    def _extract_tenant(self, request: Request) -> Optional[str]:
        """Extracts the tenant identifier from the URL.

        Uses the same prefix handling as ``_should_skip_tenant_processing`` so
        the two agree on un-prefixed (local) paths too. Returns ``None`` when
        the first segment is empty.
        """
        path = request.url.path
        logger.debug(f"Extracting tenant from request path {path}")
        return self._relative_path_parts(path)[0] or None


# ---------------------------------------------------------------------------
# stac_api/runtime/src/tenant_models.py
# ---------------------------------------------------------------------------
""" Tenant Models for STAC API """
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, field_validator

from fastapi import HTTPException
class TenantContext(BaseModel):
    """Context information for tenant-aware request processing"""

    tenant_id: str = Field(..., description="Tenant identifier")
    request_id: Optional[str] = Field(None, description="Request correlation ID")

    @field_validator("tenant_id")
    @classmethod
    def validate_tenant_id(cls, v):
        """Normalize the tenant ID (trim, lowercase) and validate it.

        Raises:
            ValueError: if the normalized ID is empty or longer than 100
                characters.
        """
        # Normalize first so surrounding whitespace does not count towards
        # the length limit.
        normalized = v.strip().lower() if v else ""
        if not normalized:
            raise ValueError("Tenant ID cannot be empty")
        if len(normalized) > 100:
            raise ValueError("Tenant ID too long")
        return normalized


class TenantSearchRequest(BaseModel):
    """Tenant-aware search request model"""

    tenant: Optional[str] = Field(None, description="Tenant identifier")
    collections: Optional[List[str]] = Field(
        None, description="Collection IDs to search"
    )
    bbox: Optional[List[float]] = Field(None, description="Bounding box")
    datetime: Optional[str] = Field(None, description="Datetime range")
    limit: int = Field(10, description="Maximum number of results")
    token: Optional[str] = Field(None, description="Pagination token")
    filter: Optional[Dict[str, Any]] = Field(None, description="CQL2 filter")
    filter_lang: str = Field("cql2-text", description="Filter language")
    conf: Optional[Dict] = None

    def add_tenant_filter(self, tenant: str) -> None:
        """AND a ``tenant = <tenant>`` condition onto the request's filter.

        Does nothing when ``tenant`` is falsy. The condition is expressed as
        a CQL2-JSON object, so ``filter_lang`` is switched to ``cql2-json``
        to keep the request self-consistent.
        """
        if not tenant:
            return

        # Create tenant filter for properties.tenant
        tenant_filter = {"op": "=", "args": [{"property": "tenant"}, tenant]}

        # If there's already a filter, combine using AND
        if self.filter:
            self.filter = {"op": "and", "args": [self.filter, tenant_filter]}
        else:
            self.filter = tenant_filter

        # The filter is now a JSON object; the default "cql2-text" label
        # would misdescribe it.
        self.filter_lang = "cql2-json"


class TenantValidationError(HTTPException):
    """404 raised when a resource does not belong to the requested tenant.

    The tenant that actually owns the resource is kept on ``actual_tenant``
    for logging but is not included in the client-facing ``detail``.
    """

    def __init__(
        self,
        resource_type: str,
        resource_id: str,
        tenant: str,
        actual_tenant: Optional[str] = None,
    ):
        """Initializes tenant validation error

        Args:
            resource_type: Human-readable kind, e.g. "Collection" or "Item".
            resource_id: Identifier of the resource that was requested.
            tenant: Tenant the request was scoped to.
            actual_tenant: Tenant found on the resource, if any.
        """
        self.resource_type = resource_type
        self.resource_id = resource_id
        self.tenant = tenant
        self.actual_tenant = actual_tenant

        # Do not reveal which tenant owns the resource: that would tell a
        # caller the resource exists elsewhere, defeating the 404.
        detail = f"{resource_type} {resource_id} not found for tenant {tenant}"

        super().__init__(status_code=404, detail=detail)


# ---------------------------------------------------------------------------
# stac_api/runtime/src/tenant_routes.py
# ---------------------------------------------------------------------------
""" Tenant Route Handler """
import json
import logging
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException, Path, Query, Request
from stac_fastapi.types.stac import Item, ItemCollection

from .tenant_client import TenantAwareVedaCrudClient
from .tenant_models import TenantSearchRequest

logger = logging.getLogger(__name__)


class TenantRouteHandler:
    """Route handler for tenant-aware STAC API endpoints.

    Every handler delegates to the tenant-aware client, lets
    ``HTTPException``s through unchanged and converts any other failure into
    a logged 500 response.
    """

    def __init__(self, client: TenantAwareVedaCrudClient):
        """Initializes tenant-aware route handler"""
        self.client = client

    async def get_tenant_collections(
        self,
        request: Request,
        tenant: str = Path(..., description="Tenant identifier"),
    ) -> Dict[str, Any]:
        """Get all collections belonging to a tenant"""
        logger.info(f"Getting collections for tenant: {tenant}")

        try:
            return await self.client.get_tenant_collections(request, tenant=tenant)
        except HTTPException:
            # Keep upstream status codes, as the sibling handlers do.
            raise
        except Exception as e:
            logger.exception(
                f"Error getting collections for tenant {tenant}: {str(e)}"
            )
            raise HTTPException(
                status_code=500, detail="Internal server error"
            ) from e

    async def get_tenant_collection(
        self,
        request: Request,
        tenant: str = Path(..., description="Tenant identifier"),
        collection_id: str = Path(..., description="Collection identifier"),
    ) -> Dict:
        """Get a specific collection belonging to a specific tenant"""
        logger.info(f"Getting collection {collection_id} for tenant: {tenant}")

        try:
            return await self.client.get_collection(
                collection_id, request, tenant=tenant
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.exception(
                f"Error getting collection {collection_id} for tenant {tenant}: {str(e)}"
            )
            raise HTTPException(
                status_code=500, detail="Internal server error"
            ) from e

    async def get_tenant_collection_items(
        self,
        request: Request,
        tenant: str = Path(..., description="Tenant identifier"),
        collection_id: str = Path(..., description="Collection identifier"),
        limit: int = Query(10, description="Maximum number of items to return"),
        token: Optional[str] = Query(None, description="Pagination token"),
    ) -> ItemCollection:
        """Get all items from a collection filtered by a specific tenant"""
        logger.info(
            f"Getting items from collection {collection_id} for tenant: {tenant}"
        )

        try:
            return await self.client.item_collection(
                collection_id=collection_id,
                request=request,
                tenant=tenant,
                limit=limit,
                token=token,
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.exception(
                f"Error getting items from collection {collection_id} for tenant {tenant}: {str(e)}"
            )
            raise HTTPException(
                status_code=500, detail="Internal server error"
            ) from e

    async def get_tenant_item(
        self,
        request: Request,
        tenant: str = Path(..., description="Tenant identifier"),
        collection_id: str = Path(..., description="Collection identifier"),
        item_id: str = Path(..., description="Item identifier"),
    ) -> Item:
        """Get a specific item for a tenant"""
        logger.info(
            f"Getting item {item_id} from collection {collection_id} for tenant: {tenant}"
        )

        try:
            return await self.client.get_item(
                item_id, collection_id, request, tenant=tenant
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.exception(
                f"Error getting item {item_id} from collection {collection_id} for tenant {tenant}: {str(e)}"
            )
            raise HTTPException(
                status_code=500, detail="Internal server error"
            ) from e

    @staticmethod
    def _parse_bbox(bbox: Optional[str]):
        """Parse a comma-separated bbox string into a list of floats.

        Returns ``None`` for a missing/empty value.

        Raises:
            HTTPException: 400 when a component is not a number.
        """
        if not bbox:
            return None
        try:
            return [float(part) for part in bbox.split(",")]
        except ValueError as exc:
            raise HTTPException(
                status_code=400, detail="Invalid bbox in search parameters"
            ) from exc

    async def get_tenant_search(
        self,
        request: Request,
        tenant: str = Path(..., description="Tenant identifier"),
        collections: Optional[str] = Query(
            None, description="Comma-separated list of collection IDs"
        ),
        ids: Optional[str] = Query(
            None, description="Comma-separated list of item IDs"
        ),
        bbox: Optional[str] = Query(None, description="Bounding box"),
        datetime: Optional[str] = Query(None, description="Datetime range"),
        limit: int = Query(10, description="Maximum number of results"),
        query: Optional[str] = Query(None, description="Query parameters"),
        token: Optional[str] = Query(None, description="Pagination token"),
        filter_lang: Optional[str] = Query("cql2-text", description="Filter language"),
        filter: Optional[str] = Query(None, description="CQL2 filter"),
        sortby: Optional[str] = Query(None, description="Sort parameters"),
    ) -> ItemCollection:
        """Search items for a specific tenant using GET.

        Query-string values are converted to the keyword arguments handed to
        the client's ``get_search``; parameters that were not supplied are
        omitted.

        Raises:
            HTTPException: 400 for malformed JSON or bbox values, 500 for any
                unexpected failure; other ``HTTPException``s pass through.
        """
        logger.info(f"GET search for tenant: {tenant}")

        try:
            # The STAC filter extension names this parameter "filter-lang";
            # honour it when sent, otherwise use the ``filter_lang`` parameter.
            effective_filter_lang = request.query_params.get(
                "filter-lang", filter_lang
            )

            # Only CQL2-JSON filters are JSON documents; CQL2-text (the
            # default) is forwarded verbatim instead of being rejected as
            # "invalid JSON".
            # NOTE(review): confirm whether the parent client expects the raw
            # string or the decoded object for ``query`` and JSON filters.
            if filter and effective_filter_lang == "cql2-json":
                parsed_filter = json.loads(filter)
            else:
                parsed_filter = filter

            search_params = {
                "collections": collections.split(",") if collections else None,
                "ids": ids.split(",") if ids else None,
                "bbox": self._parse_bbox(bbox),
                "datetime": datetime,
                "limit": limit,
                "query": json.loads(query) if query else None,
                "token": token,
                # Keyword name must be a valid identifier to reach a
                # ``filter_lang`` parameter downstream.
                "filter_lang": effective_filter_lang,
                "filter": parsed_filter,
                "sortby": sortby,
            }

            clean_params = {k: v for k, v in search_params.items() if v is not None}

            return await self.client.get_search(
                request, tenant=tenant, **clean_params
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in search parameters: {str(e)}")
            raise HTTPException(
                status_code=400, detail="Invalid JSON in search parameters"
            ) from e
        except HTTPException:
            raise
        except Exception as e:
            logger.exception(
                f"Error performing search for tenant {tenant}: {str(e)}"
            )
            raise HTTPException(
                status_code=500, detail="Internal server error"
            ) from e

    async def post_tenant_search(
        self,
        search_request: TenantSearchRequest,
        request: Request,
        tenant: str = Path(..., description="Tenant identifier"),
    ) -> ItemCollection:
        """Search items for a specific tenant using POST.

        The tenant condition is added to the request's filter before it is
        handed to the client.
        """
        logger.info(f"POST search for tenant: {tenant}")

        try:
            search_request.add_tenant_filter(tenant)

            return await self.client.post_search(
                search_request, request, tenant=tenant
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.exception(
                f"Error performing POST search for tenant {tenant}: {str(e)}"
            )
            raise HTTPException(
                status_code=500, detail="Internal server error"
            ) from e


def create_tenant_router(client: TenantAwareVedaCrudClient) -> APIRouter:
    """Create the router exposing the ``/{tenant}/...`` endpoints.

    Args:
        client: Tenant-aware client the handlers delegate to.

    Returns:
        An ``APIRouter`` with the collection, item and search routes
        registered in the order listed below.
    """
    router = APIRouter(redirect_slashes=True)
    handler = TenantRouteHandler(client)

    logger.info("Creating tenant router with routes")

    # (path, endpoint, HTTP method, summary, description)
    routes = (
        (
            "/{tenant}/collections",
            handler.get_tenant_collections,
            "GET",
            "Get collections for tenant",
            "Retrieve all collections for a specific tenant",
        ),
        (
            "/{tenant}/collections/{collection_id}",
            handler.get_tenant_collection,
            "GET",
            "Get tenant collection",
            "Retrieve a specific collection for a tenant",
        ),
        (
            "/{tenant}/collections/{collection_id}/items",
            handler.get_tenant_collection_items,
            "GET",
            "Get tenant collection items",
            "Retrieve items from a collection for a tenant",
        ),
        (
            "/{tenant}/collections/{collection_id}/items/{item_id}",
            handler.get_tenant_item,
            "GET",
            "Get tenant item",
            "Retrieve a specific item for a tenant",
        ),
        # Search endpoints
        (
            "/{tenant}/search",
            handler.get_tenant_search,
            "GET",
            "Search tenant items (GET)",
            "Search items for a tenant using GET method",
        ),
        (
            "/{tenant}/search",
            handler.post_tenant_search,
            "POST",
            "Search tenant items (POST)",
            "Search items for a tenant using POST method",
        ),
    )

    for path, endpoint, method, summary, description in routes:
        router.add_api_route(
            path,
            endpoint,
            methods=[method],
            summary=summary,
            description=description,
        )

    logger.info(f"Created tenant router with {len(router.routes)} routes")
    return router
These tri-annual averages represent periods before and after the fire.", + "item_assets": { + "cog_default": { + "type": "image/tiff; application=geotiff; profile=cloud-optimized", + "roles": ["data", "layer"], + "title": "Default COG Layer", + "description": "Cloud optimized default layer to display on map", + } + }, + "stac_version": "1.0.0", + "stac_extensions": [ + "https://stac-extensions.github.io/item-assets/v1.0.0/schema.json" + ], +} + VALID_ITEM = { "id": "OMI_trno2_0.10x0.10_2023_Col3_V4", "bbox": [-180.0, -90.0, 180.0, 90.0], @@ -152,38 +198,6 @@ ] ], }, - "proj:projjson": { - "id": {"code": 4326, "authority": "EPSG"}, - "name": "WGS 84", - "type": "GeographicCRS", - "datum": { - "name": "World Geodetic System 1984", - "type": "GeodeticReferenceFrame", - "ellipsoid": { - "name": "WGS 84", - "semi_major_axis": 6378137, - "inverse_flattening": 298.257223563, - }, - }, - "$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", - "coordinate_system": { - "axis": [ - { - "name": "Geodetic latitude", - "unit": "degree", - "direction": "north", - "abbreviation": "Lat", - }, - { - "name": "Geodetic longitude", - "unit": "degree", - "direction": "east", - "abbreviation": "Lon", - }, - ], - "subtype": "ellipsoidal", - }, - }, "proj:transform": [0.1, 0.0, -180.0, 0.0, -0.1, 90.0, 0.0, 0.0, 1.0], }, "rendered_preview": { @@ -273,7 +287,7 @@ def mock_auth(): yield mock_instance -@pytest.fixture +@pytest_asyncio.fixture async def app(): """ Fixture to initialize the FastAPI application. @@ -294,7 +308,7 @@ async def app(): await close_db_connection(app) -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def api_client(app): """ Fixture to initialize the API client for making requests. @@ -334,6 +348,17 @@ def valid_stac_collection(): return VALID_COLLECTION +@pytest.fixture +def valid_stac_collection_with_tenant(): + """ + Fixture providing a valid STAC collection with tenant for testing. 
+ + Returns: + dict: A valid STAC collection with tenant. + """ + return VALID_COLLECTION_WITH_TENANT + + @pytest.fixture def invalid_stac_collection(): """ @@ -371,8 +396,10 @@ def invalid_stac_item(): return invalid_item -@pytest.fixture -async def collection_in_db(api_client, valid_stac_collection): +@pytest_asyncio.fixture +async def collection_in_db( + api_client, valid_stac_collection, valid_stac_collection_with_tenant +): """ Fixture to ensure a valid STAC collection exists in the database. @@ -380,11 +407,20 @@ async def collection_in_db(api_client, valid_stac_collection): the collection ID. """ # Create the collection - response = await api_client.post("/collections", json=valid_stac_collection) + collection_response = await api_client.post( + "/collections", json=valid_stac_collection + ) + collection_with_tenant_response = await api_client.post( + "/collections", json=valid_stac_collection_with_tenant + ) # Ensure the setup was successful before the test proceeds # The setup is successful if the collection was created (201) or if it # already existed (409). Any other status code is a failure. - assert response.status_code in [201, 409] + assert collection_response.status_code in [201, 409] + assert collection_with_tenant_response.status_code in [201, 409] - yield valid_stac_collection["id"] + yield { + "regular_collection": valid_stac_collection["id"], + "tenant_collection": valid_stac_collection_with_tenant["id"], + } diff --git a/stac_api/runtime/tests/test_extensions.py b/stac_api/runtime/tests/test_extensions.py index 8ff18e81..0aa68ff2 100644 --- a/stac_api/runtime/tests/test_extensions.py +++ b/stac_api/runtime/tests/test_extensions.py @@ -17,10 +17,12 @@ """ +import pytest collections_endpoint = "/collections" items_endpoint = "/collections/{}/items" bulk_endpoint = "/collections/{}/bulk_items" +tenant_collections_endpoint = "/fake-tenant/collections" class TestList: @@ -33,6 +35,7 @@ class TestList: necessary data. 
""" + @pytest.mark.asyncio async def test_post_invalid_collection(self, api_client, invalid_stac_collection): """ Test the API's response to posting an invalid STAC collection. @@ -46,6 +49,7 @@ async def test_post_invalid_collection(self, api_client, invalid_stac_collection assert response.json()["detail"] == "Validation Error" assert response.status_code == 422 + @pytest.mark.asyncio async def test_post_valid_collection(self, api_client, valid_stac_collection): """ Test the API's response to posting a valid STAC collection. @@ -55,8 +59,9 @@ async def test_post_valid_collection(self, api_client, valid_stac_collection): response = await api_client.post( collections_endpoint, json=valid_stac_collection ) - assert response.status_code == 201 + assert response.status_code in [201, 409] + @pytest.mark.asyncio async def test_post_invalid_item(self, api_client, invalid_stac_item): """ Test the API's response to posting an invalid STAC item. @@ -71,18 +76,23 @@ async def test_post_invalid_item(self, api_client, invalid_stac_item): assert response.json()["detail"] == "Validation Error" assert response.status_code == 422 + @pytest.mark.asyncio async def test_post_valid_item(self, api_client, valid_stac_item, collection_in_db): """ Test the API's response to posting a valid STAC item. Asserts that the response status code is 200. """ - collection_id = valid_stac_item["collection"] + collection_id = collection_in_db["regular_collection"] + item_data = valid_stac_item.copy() + item_data["collection"] = collection_id + response = await api_client.post( items_endpoint.format(collection_id), json=valid_stac_item ) - assert response.status_code == 201 + assert response.status_code in [201, 409] # 201 for new, 409 for existing + @pytest.mark.asyncio async def test_post_invalid_bulk_items(self, api_client, invalid_stac_item): """ Test the API's response to posting invalid bulk STAC items. 
@@ -98,6 +108,7 @@ async def test_post_invalid_bulk_items(self, api_client, invalid_stac_item): ) assert response.status_code == 422 + @pytest.mark.asyncio async def test_post_valid_bulk_items( self, api_client, valid_stac_item, collection_in_db ): @@ -115,24 +126,23 @@ async def test_post_valid_bulk_items( ) assert response.status_code == 200 + @pytest.mark.asyncio async def test_get_collection_by_id(self, api_client, collection_in_db): """ Test searching for a specific collection by its ID. """ # The `collection_in_db` fixture ensures the collection exists and provides its ID. - collection_id = collection_in_db + collection_id = collection_in_db["regular_collection"] - # Perform a GET request to the /collections endpoint with an "ids" query - response = await api_client.get( - collections_endpoint, params={"ids": collection_id} - ) + response = await api_client.get(f"{collections_endpoint}/{collection_id}") assert response.status_code == 200 response_data = response.json() - assert response_data["collections"][0]["id"] == collection_id + assert response_data["id"] == collection_id + @pytest.mark.asyncio async def test_collection_freetext_search_by_title( self, api_client, collection_in_db ): @@ -141,7 +151,7 @@ async def test_collection_freetext_search_by_title( """ # The `collection_in_db` fixture ensures the collection exists. - collection_id = collection_in_db + collection_id = collection_in_db["regular_collection"] # Use a unique word from the collection's title for the query. search_term = "precipitation" @@ -156,3 +166,75 @@ async def test_collection_freetext_search_by_title( returned_ids = [col["id"] for col in response_data["collections"]] assert collection_id in returned_ids + + @pytest.mark.asyncio + async def test_get_collections_by_tenant(self, api_client, collection_in_db): + """ + Test listing collections through a tenant-scoped route.
+ """ + collection_id = collection_in_db["tenant_collection"] + + # Perform a GET request to the tenant-scoped collections endpoint + response = await api_client.get( + tenant_collections_endpoint, + ) + + assert response.status_code == 200 + + response_data = response.json() + + assert collection_id in [col["id"] for col in response_data["collections"]] + + @pytest.mark.asyncio + async def test_get_collection_by_id_with_tenant(self, api_client, collection_in_db): + """ + Test searching for a specific collection by its ID and tenant + """ + # The `collection_in_db` fixture ensures the collection exists and provides its ID. + collection_id = collection_in_db["tenant_collection"] + + # Perform a GET request to the /fake-tenant/collections endpoint with an "ids" query + response = await api_client.get( + tenant_collections_endpoint, params={"ids": collection_id} + ) + + assert response.status_code == 200 + + response_data = response.json() + + assert response_data["collections"][0]["id"] == collection_id + + @pytest.mark.asyncio + async def test_tenant_validation_error(self, api_client, collection_in_db): + """ + Test that accessing a collection through the wrong tenant returns 404 + """ + collection_id = collection_in_db["tenant_collection"] + + # Request the fake-tenant collection through a nonexistent tenant + response = await api_client.get(f"/fake-tenant-2/collections/{collection_id}") + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_invalid_tenant_format(self, api_client): + """ + Test handling of invalid tenant formats + """ + + # Nonexistent tenant should just show no collections + response = await api_client.get("/invalid-tenant-format/collections") + + assert response.status_code in [200, 404] + if response.status_code == 200: + response_data = response.json() + assert response_data["collections"] == [] + + @pytest.mark.asyncio + async def test_missing_tenant_parameter(self, api_client): + """ + Test behavior when tenant parameter is not
supplied in route path + """ + + response = await api_client.get("/collections") + # Should return all collections (no tenant filtering) + assert response.status_code == 200