diff --git a/Dockerfile b/Dockerfile index 454fd161..8a28dbc4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,8 +23,7 @@ COPY scripts/wait-for-it.sh scripts/wait-for-it.sh COPY pyproject.toml pyproject.toml COPY README.md README.md -RUN python -m pip install .[server] -RUN rm -rf stac_fastapi .toml README.md +RUN python -m pip install -e .[server,catalogs] RUN groupadd -g 1000 user && \ useradd -u 1000 -g user -s /bin/bash -m user diff --git a/Dockerfile.tests b/Dockerfile.tests index 2dcceee5..097c3e77 100644 --- a/Dockerfile.tests +++ b/Dockerfile.tests @@ -16,4 +16,4 @@ USER newuser WORKDIR /app COPY . /app -RUN python -m pip install . --user --group dev +RUN python -m pip install .[catalogs] --user --group dev diff --git a/Makefile b/Makefile index 65fa32f8..e4d13d2b 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ run = docker compose run --rm \ -e APP_PORT=${APP_PORT} \ app -runtests = docker compose run --rm tests +runtests = docker compose -f compose-tests.yml run --rm tests .PHONY: image image: @@ -22,7 +22,7 @@ docker-run: image .PHONY: docker-run-nginx-proxy docker-run-nginx-proxy: - docker compose -f docker-compose.yml -f docker-compose.nginx.yml up + docker compose -f compose.yml -f docker-compose.nginx.yml up .PHONY: docker-shell docker-shell: @@ -32,6 +32,10 @@ docker-shell: test: $(runtests) /bin/bash -c 'export && python -m pytest /app/tests/ --log-cli-level $(LOG_LEVEL)' +.PHONY: test-catalogs +test-catalogs: + $(runtests) /bin/bash -c 'export && python -m pytest /app/tests/test_catalogs.py -v --log-cli-level $(LOG_LEVEL)' + .PHONY: run-database run-database: docker compose run --rm database diff --git a/docker-compose.yml b/compose-tests.yml similarity index 87% rename from docker-compose.yml rename to compose-tests.yml index 5aec9a9e..052833aa 100644 --- a/docker-compose.yml +++ b/compose-tests.yml @@ -1,6 +1,7 @@ services: app: image: stac-utils/stac-fastapi-pgstac + restart: always build: . environment: - APP_HOST=0.0.0.0 @@ -20,15 +21,19 @@ services: - DB_MAX_CONN_SIZE=1 - USE_API_HYDRATE=${USE_API_HYDRATE:-false} - ENABLE_TRANSACTIONS_EXTENSIONS=TRUE - ports: - - "8082:8082" + - ENABLE_CATALOGS_ROUTE=TRUE + # ports: + # - "8082:8082" depends_on: - database - command: bash -c "scripts/wait-for-it.sh database:5432 && python -m stac_fastapi.pgstac.app" + command: bash -c "scripts/wait-for-it.sh database:5432 && uvicorn stac_fastapi.pgstac.app:app --host 0.0.0.0 --port 8082 --reload" develop: watch: - - action: rebuild + - action: sync path: ./stac_fastapi/pgstac + target: /app/stac_fastapi/pgstac + - action: rebuild + path: ./setup.py tests: image: stac-utils/stac-fastapi-pgstac-test @@ -40,6 +45,7 @@ services: - DB_MIN_CONN_SIZE=1 - DB_MAX_CONN_SIZE=1 - USE_API_HYDRATE=${USE_API_HYDRATE:-false} + - ENABLE_CATALOGS_ROUTE=TRUE command: bash -c "python -m pytest -s -vv" database: diff --git a/compose.yml b/compose.yml new file mode 100644 index 00000000..869ae6ef --- /dev/null +++ b/compose.yml @@ -0,0 +1,91 @@ +services: + app: + image: stac-utils/stac-fastapi-pgstac + restart: always + build: . + environment: + - APP_HOST=0.0.0.0 + - APP_PORT=8082 + - RELOAD=true + - ENVIRONMENT=local + - PGUSER=username + - PGPASSWORD=password + - PGDATABASE=postgis + - PGHOST=database + - PGPORT=5432 + - WEB_CONCURRENCY=10 + - VSI_CACHE=TRUE + - GDAL_HTTP_MERGE_CONSECUTIVE_RANGES=YES + - GDAL_DISABLE_READDIR_ON_OPEN=EMPTY_DIR + - DB_MIN_CONN_SIZE=1 + - DB_MAX_CONN_SIZE=1 + - USE_API_HYDRATE=${USE_API_HYDRATE:-false} + - ENABLE_TRANSACTIONS_EXTENSIONS=TRUE + - ENABLE_CATALOGS_ROUTE=TRUE + # ports: + # - "8082:8082" + depends_on: + - database + command: bash -c "scripts/wait-for-it.sh database:5432 && uvicorn stac_fastapi.pgstac.app:app --host 0.0.0.0 --port 8082 --reload" + develop: + watch: + - action: sync + path: ./stac_fastapi/pgstac + target: /app/stac_fastapi/pgstac + - action: rebuild + path: ./setup.py + + database: + image: ghcr.io/stac-utils/pgstac:v0.9.8 + environment: + - POSTGRES_USER=username + - POSTGRES_PASSWORD=password + - POSTGRES_DB=postgis + - PGUSER=username + - PGPASSWORD=password + - PGDATABASE=postgis + ports: + - "5439:5432" + command: postgres -N 500 + + # Load joplin demo dataset into the PGStac Application + loadjoplin: + image: stac-utils/stac-fastapi-pgstac + environment: + - ENVIRONMENT=development + volumes: + - ./testdata:/tmp/testdata + - ./scripts:/tmp/scripts + command: > + /bin/sh -c " + scripts/wait-for-it.sh -t 60 app:8082 && + python -m pip install pip -U && + python -m pip install requests && + python /tmp/scripts/ingest_joplin.py http://app:8082 + " + depends_on: + - database + - app + + nginx: + image: nginx + ports: + - ${STAC_FASTAPI_NGINX_PORT:-8080}:80 + volumes: + - ./nginx.conf:/etc/nginx/nginx.conf + depends_on: + - app-nginx + command: [ "nginx-debug", "-g", "daemon off;" ] + + app-nginx: + extends: + service: app + command: > + bash -c " + scripts/wait-for-it.sh database:5432 && + uvicorn stac_fastapi.pgstac.app:app --host 0.0.0.0 --port 8082 --proxy-headers --forwarded-allow-ips=* --root-path=/api/v1/pgstac + " + +networks: + default: + name: stac-fastapi-network diff --git a/pyproject.toml b/pyproject.toml index cbd87da2..253dbe62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "stac-fastapi-pgstac" description = "An implementation of STAC API based on the FastAPI framework and using the pgstac backend." readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.12" license = "MIT" authors = [ { name = "David Bitner", email = "david@developmentseed.org" }, @@ -55,6 +55,9 @@ validation = [ server = [ "uvicorn[standard]==0.38.0" ] +catalogs = [ + "stac-fastapi-catalogs-extension>=0.1.2", +] [dependency-groups] dev = [ diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 844bd49f..43303fd3 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -5,6 +5,7 @@ If the variable is not set, enables all extensions. """ +import logging import os from contextlib import asynccontextmanager from typing import cast @@ -45,13 +46,35 @@ from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension +from stac_fastapi.pgstac.extensions import ( + DatabaseLogic, + FreeTextExtension, + QueryExtension, +) +from stac_fastapi.pgstac.extensions.catalogs.catalogs_client import CatalogsClient from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch +logger = logging.getLogger(__name__) + +# Optional catalogs extension (optional dependency) +try: + from stac_fastapi_catalogs_extension import CatalogsExtension +except ImportError: + CatalogsExtension = None + settings = Settings() + +def _is_env_flag_enabled(name: str) -> bool: + """Return True if the given env var is enabled. + + Accepts common truthy values ("yes", "true", "1") case-insensitively. + """ + return os.environ.get(name, "").lower() in ("yes", "true", "1") + + # search extensions search_extensions_map: dict[str, ApiExtension] = { "query": QueryExtension(), @@ -98,11 +121,7 @@ application_extensions: list[ApiExtension] = [] -with_transactions = os.environ.get("ENABLE_TRANSACTIONS_EXTENSIONS", "").lower() in [ - "yes", - "true", - "1", -] +with_transactions = _is_env_flag_enabled("ENABLE_TRANSACTIONS_EXTENSIONS") if with_transactions: application_extensions.append( TransactionExtension( @@ -158,6 +177,27 @@ collections_get_request_model = collection_search_extension.GET application_extensions.append(collection_search_extension) +# Optional catalogs route +ENABLE_CATALOGS_ROUTE = _is_env_flag_enabled("ENABLE_CATALOGS_ROUTE") +logger.info("ENABLE_CATALOGS_ROUTE is set to %s", ENABLE_CATALOGS_ROUTE) + +if ENABLE_CATALOGS_ROUTE: + if CatalogsExtension is None: + logger.warning( + "ENABLE_CATALOGS_ROUTE is set to true, but the catalogs extension is not installed. " + "Please install it with: pip install stac-fastapi-core[catalogs].", + ) + else: + try: + catalogs_extension = CatalogsExtension( + client=CatalogsClient(database=DatabaseLogic()), + enable_transactions=with_transactions, + ) + application_extensions.append(catalogs_extension) + print("CatalogsExtension enabled successfully.") + except Exception as e: # pragma: no cover - defensive + logger.warning("Failed to initialize CatalogsExtension: %s", e) + @asynccontextmanager async def lifespan(app: FastAPI): diff --git a/stac_fastapi/pgstac/extensions/__init__.py b/stac_fastapi/pgstac/extensions/__init__.py index 6c2812b6..8c5738f2 100644 --- a/stac_fastapi/pgstac/extensions/__init__.py +++ b/stac_fastapi/pgstac/extensions/__init__.py @@ -1,7 +1,15 @@ """pgstac extension customisations.""" +from .catalogs.catalogs_client import CatalogsClient +from .catalogs.catalogs_database_logic import DatabaseLogic from .filter import FiltersClient from .free_text import FreeTextExtension from .query import QueryExtension -__all__ = ["QueryExtension", "FiltersClient", "FreeTextExtension"] +__all__ = [ + "QueryExtension", + "FiltersClient", + "FreeTextExtension", + "CatalogsClient", + "DatabaseLogic", +] diff --git a/stac_fastapi/pgstac/extensions/catalogs/catalogs_client.py b/stac_fastapi/pgstac/extensions/catalogs/catalogs_client.py new file mode 100644 index 00000000..16830a5f --- /dev/null +++ b/stac_fastapi/pgstac/extensions/catalogs/catalogs_client.py @@ -0,0 +1,303 @@ +"""Catalogs client implementation for pgstac.""" + +import logging +from typing import Any, cast + +import attr +from fastapi import Request +from stac_fastapi_catalogs_extension.client import AsyncBaseCatalogsClient +from stac_fastapi.types import stac as stac_types +from starlette.responses import JSONResponse + +from stac_fastapi.types.errors import NotFoundError + +logger = logging.getLogger(__name__) + + +@attr.s +class CatalogsClient(AsyncBaseCatalogsClient): + """Catalogs client implementation for pgstac. + + This client implements the AsyncBaseCatalogsClient interface and delegates + to the database layer for all catalog operations. + """ + + database: Any = attr.ib() + + async def get_catalogs( + self, + limit: int | None = None, + token: str | None = None, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get all catalogs.""" + limit = limit or 10 + catalogs_list, next_token, total_hits = await self.database.get_all_catalogs( + token=token, + limit=limit, + request=request, + ) + + return JSONResponse( + content={ + "catalogs": catalogs_list or [], + "links": [], + "numberMatched": total_hits, + "numberReturned": len(catalogs_list) if catalogs_list else 0, + } + ) + + async def get_catalog( + self, catalog_id: str, request: Request | None = None, **kwargs + ) -> JSONResponse: + """Get a specific catalog by ID.""" + try: + catalog = await self.database.find_catalog(catalog_id, request=request) + return JSONResponse(content=catalog) + except NotFoundError: + raise + + async def create_catalog( + self, catalog: dict, request: Request | None = None, **kwargs + ) -> stac_types.Catalog: + """Create a new catalog.""" + # Convert Pydantic model to dict if needed + catalog_dict = cast(stac_types.Catalog, catalog.model_dump(mode="json") if hasattr(catalog, "model_dump") else catalog) + + await self.database.create_catalog(dict(catalog_dict), refresh=True, request=request) + return catalog_dict + + async def update_catalog( + self, catalog_id: str, catalog: dict, request: Request | None = None, **kwargs + ) -> stac_types.Catalog: + """Update an existing catalog.""" + # Convert Pydantic model to dict if needed + catalog_dict = cast(stac_types.Catalog, catalog.model_dump(mode="json") if hasattr(catalog, "model_dump") else catalog) + + await self.database.create_catalog(dict(catalog_dict), refresh=True, request=request) + return catalog_dict + + async def delete_catalog( + self, catalog_id: str, request: Request | None = None, **kwargs + ) -> None: + """Delete a catalog.""" + await self.database.delete_catalog(catalog_id, refresh=True, request=request) + + async def get_catalog_collections( + self, + catalog_id: str, + limit: int | None = None, + token: str | None = None, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get collections in a catalog.""" + limit = limit or 10 + collections_list, total_hits, next_token = await self.database.get_catalog_collections( + catalog_id=catalog_id, + limit=limit, + token=token, + request=request, + ) + return JSONResponse( + content={ + "collections": collections_list or [], + "links": [], + "numberMatched": total_hits, + "numberReturned": len(collections_list) if collections_list else 0, + } + ) + + async def get_sub_catalogs( + self, + catalog_id: str, + limit: int | None = None, + token: str | None = None, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get sub-catalogs.""" + limit = limit or 10 + catalogs_list, total_hits, next_token = await self.database.get_catalog_catalogs( + catalog_id=catalog_id, + limit=limit, + token=token, + request=request, + ) + return JSONResponse( + content={ + "catalogs": catalogs_list or [], + "links": [], + "numberMatched": total_hits, + "numberReturned": len(catalogs_list) if catalogs_list else 0, + } + ) + + async def create_sub_catalog( + self, catalog_id: str, catalog: dict, request: Request | None = None, **kwargs + ) -> JSONResponse: + """Create a sub-catalog.""" + # Convert Pydantic model to dict if needed + if hasattr(catalog, "model_dump"): + catalog_dict = catalog.model_dump(mode="json") + else: + catalog_dict = dict(catalog) if not isinstance(catalog, dict) else catalog + + catalog_dict["parent_ids"] = [catalog_id] + await self.database.create_catalog(catalog_dict, refresh=True, request=request) + return JSONResponse(content=catalog_dict, status_code=201) + + async def create_catalog_collection( + self, catalog_id: str, collection: dict, request: Request | None = None, **kwargs + ) -> JSONResponse: + """Create a collection in a catalog.""" + # Convert Pydantic model to dict if needed + if hasattr(collection, "model_dump"): + collection_dict = collection.model_dump(mode="json") + else: + collection_dict = dict(collection) if not isinstance(collection, dict) else collection + + collection_dict["parent_ids"] = [catalog_id] + await self.database.create_collection(collection_dict, refresh=True, request=request) + return JSONResponse(content=collection_dict, status_code=201) + + async def get_catalog_collection( + self, + catalog_id: str, + collection_id: str, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get a collection from a catalog.""" + collection = await self.database.get_catalog_collection( + catalog_id=catalog_id, + collection_id=collection_id, + request=request, + ) + return JSONResponse(content=collection) + + async def unlink_catalog_collection( + self, + catalog_id: str, + collection_id: str, + request: Request | None = None, + **kwargs, + ) -> None: + """Unlink a collection from a catalog.""" + collection = await self.database.get_catalog_collection( + catalog_id=catalog_id, + collection_id=collection_id, + request=request, + ) + if "parent_ids" in collection: + collection["parent_ids"] = [ + pid for pid in collection["parent_ids"] if pid != catalog_id + ] + await self.database.update_collection( + collection_id, collection, refresh=True, request=request + ) + + async def get_catalog_collection_items( + self, + catalog_id: str, + collection_id: str, + limit: int | None = None, + token: str | None = None, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get items from a collection in a catalog.""" + limit = limit or 10 + items, total, next_token = await self.database.get_catalog_collection_items( + catalog_id=catalog_id, + collection_id=collection_id, + limit=limit, + token=token, + request=request, + ) + return JSONResponse( + content={ + "type": "FeatureCollection", + "features": items or [], + "links": [], + "numberMatched": total, + "numberReturned": len(items) if items else 0, + } + ) + + async def get_catalog_collection_item( + self, + catalog_id: str, + collection_id: str, + item_id: str, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get a specific item from a collection in a catalog.""" + item = await self.database.get_catalog_collection_item( + catalog_id=catalog_id, + collection_id=collection_id, + item_id=item_id, + request=request, + ) + return JSONResponse(content=item) + + async def get_catalog_children( + self, + catalog_id: str, + limit: int | None = None, + token: str | None = None, + request: Request | None = None, + **kwargs, + ) -> JSONResponse: + """Get all children of a catalog.""" + limit = limit or 10 + children_list, total_hits, next_token = await self.database.get_catalog_children( + catalog_id=catalog_id, + limit=limit, + token=token, + request=request, + ) + return JSONResponse( + content={ + "children": children_list or [], + "links": [], + "numberMatched": total_hits, + "numberReturned": len(children_list) if children_list else 0, + } + ) + + async def get_catalog_conformance( + self, catalog_id: str, request: Request | None = None, **kwargs + ) -> JSONResponse: + """Get conformance classes for a catalog.""" + return JSONResponse( + content={ + "conformsTo": [ + "https://api.stacspec.org/v1.0.0/core", + "https://api.stacspec.org/v1.0.0/multi-tenant-catalogs", + ] + } + ) + + async def get_catalog_queryables( + self, catalog_id: str, request: Request | None = None, **kwargs + ) -> JSONResponse: + """Get queryables for a catalog.""" + return JSONResponse(content={"queryables": []}) + + async def unlink_sub_catalog( + self, + catalog_id: str, + sub_catalog_id: str, + request: Request | None = None, + **kwargs, + ) -> None: + """Unlink a sub-catalog from its parent.""" + sub_catalog = await self.database.find_catalog(sub_catalog_id, request=request) + if "parent_ids" in sub_catalog: + sub_catalog["parent_ids"] = [ + pid for pid in sub_catalog["parent_ids"] if pid != catalog_id + ] + await self.database.create_catalog(sub_catalog, refresh=True, request=request) diff --git a/stac_fastapi/pgstac/extensions/catalogs/catalogs_database_logic.py b/stac_fastapi/pgstac/extensions/catalogs/catalogs_database_logic.py new file mode 100644 index 00000000..7024b1f9 --- /dev/null +++ b/stac_fastapi/pgstac/extensions/catalogs/catalogs_database_logic.py @@ -0,0 +1,449 @@ +import json +import logging +from typing import Any, cast + +from buildpg import render +from fastapi import Request +from stac_fastapi.pgstac.db import dbfunc +from stac_fastapi.types import stac as stac_types +from stac_fastapi.types.errors import NotFoundError + +logger = logging.getLogger(__name__) + + +class DatabaseLogic: + """Database logic for catalogs extension using PGStac.""" + + async def get_all_catalogs( + self, + token: str | None, + limit: int, + request: Any = None, + sort: list[dict[str, Any]] | None = None, + ) -> tuple[list[dict[str, Any]], str | None, int | None]: + """Retrieve a list of catalogs from PGStac, supporting pagination. + + Args: + token (str | None): The pagination token. + limit (int): The number of results to return. + request (Any, optional): The FastAPI request object. Defaults to None. + sort (list[dict[str, Any]] | None, optional): Optional sort parameter. Defaults to None. + + Returns: + A tuple of (catalogs, next pagination token if any, optional count). + """ + if request is None: + logger.debug("No request object provided to get_all_catalogs") + return [], None, None + + try: + async with request.app.state.get_connection(request, "r") as conn: + logger.debug("Attempting to fetch all catalogs from database") + q, p = render( + """ + SELECT content + FROM collections + WHERE content->>'type' = 'Catalog' + ORDER BY id + LIMIT :limit OFFSET 0; + """, + limit=limit, + ) + rows = await conn.fetch(q, *p) + catalogs = [row[0] for row in rows] if rows else [] + logger.info(f"Successfully fetched {len(catalogs)} catalogs") + except Exception as e: + logger.warning(f"Error fetching all catalogs: {e}") + catalogs = [] + + return catalogs, None, len(catalogs) if catalogs else None + + async def find_catalog(self, catalog_id: str, request: Any = None) -> dict[str, Any]: + """Find a catalog by ID. + + Args: + catalog_id: The catalog ID to find. + request: The FastAPI request object. + + Returns: + The catalog dictionary. + + Raises: + NotFoundError: If the catalog is not found. + """ + if request is None: + raise NotFoundError(f"Catalog {catalog_id} not found") + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT content + FROM collections + WHERE id = :id AND content->>'type' = 'Catalog'; + """, + id=catalog_id, + ) + row = await conn.fetchval(q, *p) + catalog = row if row else None + except Exception: + catalog = None + + if catalog is None: + raise NotFoundError(f"Catalog {catalog_id} not found") + + return catalog + + async def create_catalog( + self, catalog: dict[str, Any], refresh: bool = False, request: Any = None + ) -> None: + """Create or update a catalog. + + Args: + catalog: The catalog dictionary. + refresh: Whether to refresh after creation. + request: The FastAPI request object. + """ + if request is None: + return + + try: + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "create_collection", dict(catalog)) + except Exception as e: + logger.warning(f"Error creating catalog: {e}") + + async def delete_catalog( + self, catalog_id: str, refresh: bool = False, request: Any = None + ) -> None: + """Delete a catalog. + + Args: + catalog_id: The catalog ID to delete. + refresh: Whether to refresh after deletion. + request: The FastAPI request object. + """ + if request is None: + return + + try: + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "delete_collection", catalog_id) + except Exception as e: + logger.warning(f"Error deleting catalog: {e}") + + async def get_catalog_children( + self, + catalog_id: str, + limit: int = 10, + token: str | None = None, + request: Any = None, + ) -> tuple[list[dict[str, Any]], int | None, str | None]: + """Get all children (catalogs and collections) of a catalog. + + Args: + catalog_id: The parent catalog ID. + limit: The number of results to return. + token: The pagination token. + request: The FastAPI request object. + + Returns: + A tuple of (children list, total count, next token). + """ + if request is None: + return [], None, None + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT content + FROM collections + WHERE content->'parent_ids' @> :parent_id::jsonb + ORDER BY content->>'type' DESC, id + LIMIT :limit OFFSET 0; + """, + parent_id=f'"{catalog_id}"', + limit=limit, + ) + rows = await conn.fetch(q, *p) + children = [row[0] for row in rows] if rows else [] + except Exception: + children = [] + + return children[:limit], len(children) if children else None, None + + async def get_catalog_collections( + self, + catalog_id: str, + limit: int = 10, + token: str | None = None, + request: Any = None, + ) -> tuple[list[dict[str, Any]], int | None, str | None]: + """Get collections linked to a catalog. + + Args: + catalog_id: The catalog ID. + limit: The number of results to return. + token: The pagination token. + request: The FastAPI request object. + + Returns: + A tuple of (collections list, total count, next token). + """ + if request is None: + return [], None, None + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT content + FROM collections + WHERE content->>'type' = 'Collection' AND content->'parent_ids' @> :parent_id::jsonb + ORDER BY id + LIMIT :limit OFFSET 0; + """, + parent_id=f'"{catalog_id}"', + limit=limit, + ) + rows = await conn.fetch(q, *p) + collections = [row[0] for row in rows] if rows else [] + except Exception: + collections = [] + + return collections[:limit], len(collections) if collections else None, None + + async def get_catalog_catalogs( + self, + catalog_id: str, + limit: int = 10, + token: str | None = None, + request: Any = None, + ) -> tuple[list[dict[str, Any]], int | None, str | None]: + """Get sub-catalogs of a catalog. + + Args: + catalog_id: The parent catalog ID. + limit: The number of results to return. + token: The pagination token. + request: The FastAPI request object. + + Returns: + A tuple of (catalogs list, total count, next token). + """ + if request is None: + return [], None, None + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT content + FROM collections + WHERE content->>'type' = 'Catalog' AND content->'parent_ids' @> :parent_id::jsonb + ORDER BY id + LIMIT :limit OFFSET 0; + """, + parent_id=f'"{catalog_id}"', + limit=limit, + ) + rows = await conn.fetch(q, *p) + catalogs = [row[0] for row in rows] if rows else [] + except Exception: + catalogs = [] + + return catalogs[:limit], len(catalogs) if catalogs else None, None + + async def find_collection( + self, collection_id: str, request: Any = None + ) -> dict[str, Any]: + """Find a collection by ID. + + Args: + collection_id: The collection ID to find. + request: The FastAPI request object. + + Returns: + The collection dictionary. + + Raises: + NotFoundError: If the collection is not found. + """ + if request is None: + raise NotFoundError(f"Collection {collection_id} not found") + + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM get_collection(:id::text); + """, + id=collection_id, + ) + collection = await conn.fetchval(q, *p) + + if collection is None: + raise NotFoundError(f"Collection {collection_id} not found") + + return collection + + async def create_collection( + self, collection: dict[str, Any], refresh: bool = False, request: Any = None + ) -> None: + """Create a collection. + + Args: + collection: The collection dictionary. + refresh: Whether to refresh after creation. + request: The FastAPI request object. + """ + if request is None: + return + + try: + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "create_collection", dict(collection)) + except Exception as e: + logger.warning(f"Error creating collection: {e}") + + async def update_collection( + self, + collection_id: str, + collection: dict[str, Any], + refresh: bool = False, + request: Any = None, + ) -> None: + """Update a collection. + + Args: + collection_id: The collection ID to update. + collection: The collection dictionary. + refresh: Whether to refresh after update. + request: The FastAPI request object. + """ + if request is None: + return + + async with request.app.state.get_connection(request, "w") as conn: + q, p = render( + """ + SELECT * FROM update_collection(:item::text::jsonb); + """, + item=json.dumps(collection), + ) + await conn.fetchval(q, *p) + + async def get_catalog_collection( + self, + catalog_id: str, + collection_id: str, + request: Any = None, + ) -> dict[str, Any]: + """Get a specific collection from a catalog. + + Args: + catalog_id: The catalog ID. + collection_id: The collection ID. + request: The FastAPI request object. + + Returns: + The collection dictionary. + + Raises: + NotFoundError: If the collection is not found. + """ + if request is None: + raise NotFoundError(f"Collection {collection_id} not found") + + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM get_collection(:id::text); + """, + id=collection_id, + ) + collection = await conn.fetchval(q, *p) + + if collection is None: + raise NotFoundError(f"Collection {collection_id} not found") + + return collection + + async def get_catalog_collection_items( + self, + catalog_id: str, + collection_id: str, + bbox: Any = None, + datetime: str | None = None, + limit: int = 10, + token: str | None = None, + request: Any = None, + **kwargs: Any, + ) -> tuple[list[dict[str, Any]], int | None, str | None]: + """Get items from a collection in a catalog. + + Args: + catalog_id: The catalog ID. + collection_id: The collection ID. + bbox: Bounding box filter. + datetime: Datetime filter. + limit: The number of results to return. + token: The pagination token. + request: The FastAPI request object. + **kwargs: Additional arguments. + + Returns: + A tuple of (items list, total count, next token). + """ + if request is None: + return [], None, None + + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM get_collection_items(:collection_id::text); + """, + collection_id=collection_id, + ) + items = await conn.fetchval(q, *p) or [] + + return items[:limit], len(items), None + + async def get_catalog_collection_item( + self, + catalog_id: str, + collection_id: str, + item_id: str, + request: Any = None, + ) -> dict[str, Any]: + """Get a specific item from a collection in a catalog. + + Args: + catalog_id: The catalog ID. + collection_id: The collection ID. + item_id: The item ID. + request: The FastAPI request object. + + Returns: + The item dictionary. + + Raises: + NotFoundError: If the item is not found. + """ + if request is None: + raise NotFoundError(f"Item {item_id} not found") + + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM get_item(:item_id::text, :collection_id::text); + """, + item_id=item_id, + collection_id=collection_id, + ) + item = await conn.fetchval(q, *p) + + if item is None: + raise NotFoundError(f"Item {item_id} not found") + + return item \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 29c16a2f..3dd09e62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,11 +46,22 @@ from stac_fastapi.pgstac.config import PostgresSettings, Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension +from stac_fastapi.pgstac.extensions import ( + DatabaseLogic, + FreeTextExtension, + QueryExtension, +) +from stac_fastapi.pgstac.extensions.catalogs.catalogs_client import CatalogsClient from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch +# Optional catalogs extension +try: + from stac_fastapi_catalogs_extension import CatalogsExtension +except ImportError: + CatalogsExtension = None + DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -129,6 +140,14 @@ def api_client(request): BulkTransactionExtension(client=BulkTransactionsClient()), ] + # Add catalogs extension if available + if CatalogsExtension is not None: + catalogs_extension = CatalogsExtension( + client=CatalogsClient(database=DatabaseLogic()), + enable_transactions=True, + ) + application_extensions.append(catalogs_extension) + search_extensions = [ QueryExtension(), SortExtension(), diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py new file mode 100644 index 00000000..25394b9c --- /dev/null +++ b/tests/test_catalogs.py @@ -0,0 +1,113 @@ +"""Tests for the catalogs extension.""" + +import logging +from urllib.parse import urlparse + +import pytest + +logger = logging.getLogger(__name__) + + +def has_router_prefix(app_client): + """Check if the app_client has a router prefix.""" + parsed = urlparse(str(app_client.base_url)) + return "/router_prefix" in parsed.path + + +@pytest.mark.asyncio +async def test_create_catalog(app_client): + """Test creating a catalog.""" + if has_router_prefix(app_client): + pytest.skip("Catalogs extension routes not registered with router prefix") + + catalog_data = { + "id": "test-catalog", + "type": "Catalog", + "description": "A test catalog", + "stac_version": "1.0.0", + "links": [], + } + + resp = await app_client.post( + "/catalogs", + json=catalog_data, + ) + assert resp.status_code == 201 + created_catalog = resp.json() + assert created_catalog["id"] == "test-catalog" + assert created_catalog["type"] == "Catalog" + assert created_catalog["description"] == "A test catalog" + + +@pytest.mark.asyncio +async def test_get_all_catalogs(app_client): + """Test getting all catalogs.""" + if has_router_prefix(app_client): + pytest.skip("Catalogs extension routes not registered with router prefix") + + # Create three catalogs + catalog_ids = ["test-catalog-1", "test-catalog-2", "test-catalog-3"] + for catalog_id in catalog_ids: + catalog_data = { + "id": catalog_id, + "type": "Catalog", + "description": f"Test catalog {catalog_id}", + "stac_version": "1.0.0", + "links": [], + } + + resp = await app_client.post( + "/catalogs", + json=catalog_data, + ) + assert resp.status_code == 201 + + # Now get all catalogs + resp = await app_client.get("/catalogs") + assert resp.status_code == 200 + data = resp.json() + assert "catalogs" in data + assert isinstance(data["catalogs"], list) + assert len(data["catalogs"]) >= 3 + + # Check that all three created catalogs are in the list + returned_catalog_ids = [cat.get("id") for cat in data["catalogs"]] + for catalog_id in catalog_ids: + assert catalog_id in returned_catalog_ids + + +@pytest.mark.asyncio +async def test_get_catalog_by_id(app_client): + """Test getting a specific catalog by ID.""" + if has_router_prefix(app_client): + pytest.skip("Catalogs extension routes not registered with router prefix") + + # First create a catalog + catalog_data = { + "id": "test-catalog-get", + "type": "Catalog", + "description": "A test catalog for getting", + "stac_version": "1.0.0", + "links": [], + } + + resp = await app_client.post( + "/catalogs", + json=catalog_data, + ) + assert resp.status_code == 201 + + # Now get the specific catalog + resp = await app_client.get("/catalogs/test-catalog-get") + assert resp.status_code == 200 + retrieved_catalog = resp.json() + assert retrieved_catalog["id"] == "test-catalog-get" + assert retrieved_catalog["type"] == "Catalog" + assert retrieved_catalog["description"] == "A test catalog for getting" + + +@pytest.mark.asyncio +async def test_get_nonexistent_catalog(app_client): + """Test getting a catalog that doesn't exist.""" + resp = await app_client.get("/catalogs/nonexistent-catalog-id") + assert resp.status_code == 404