diff --git a/stac_fastapi/pgstac/types/search.py b/stac_fastapi/pgstac/types/search.py index b650dbe0..be40164e 100644 --- a/stac_fastapi/pgstac/types/search.py +++ b/stac_fastapi/pgstac/types/search.py @@ -1,8 +1,10 @@ """stac_fastapi.types.search module.""" -from pydantic import ValidationInfo, field_validator +from pydantic import ValidationInfo, field_validator, model_validator from stac_fastapi.types.search import BaseSearchPostRequest +from stac_fastapi.pgstac.utils import clean_exclude + class PgstacSearch(BaseSearchPostRequest): """Search model. @@ -22,3 +24,11 @@ def validate_query_uses_cql(cls, v: str, info: ValidationInfo): ) return v + + @model_validator(mode="after") + def validate_fields(self): + """Clean the exclude to give the include set precedence.""" + fields = getattr(self, "fields", None) + if fields and fields.include and fields.exclude: + fields.exclude = clean_exclude(fields.include, fields.exclude) + return self diff --git a/stac_fastapi/pgstac/utils.py b/stac_fastapi/pgstac/utils.py index 1a7bd49d..1e259e77 100644 --- a/stac_fastapi/pgstac/utils.py +++ b/stac_fastapi/pgstac/utils.py @@ -5,81 +5,123 @@ from stac_fastapi.types.stac import Item -def filter_fields( # noqa: C901 - item: Item, - include: set[str] | None = None, - exclude: set[str] | None = None, -) -> Item: - """Preserve and remove fields as indicated by the fields extension include/exclude sets. +def clean_exclude( + include: set[str], + exclude: set[str], +) -> set[str]: + """Clean the exclude set to ensure precedence of the include set. + + Cleaning includes: + - Removing any fields from the exclude set that are also in the include set, since + the include set takes precedence. + - Removing any fields from the exclude set that are parent paths of fields in the include set, + since including a sub-field of an excluded parent field should take precedence. + """ + intersection = include.intersection(exclude) + if intersection: + exclude = exclude - intersection + for field_excluded in exclude: + for field_included in include: + if field_included.startswith(field_excluded + "."): + exclude = exclude - {field_excluded} + pass + return exclude - Returns a shallow copy of the Item with the fields filtered. - This will not perform a deep copy; values of the original item will be referenced - in the return item. +def dict_deep_update(merge_to: dict[str, Any], merge_from: dict[str, Any]) -> None: + """Perform a deep update of two dicts. + + merge_to is updated in-place with the values from merge_from. + merge_from values take precedence over existing values in merge_to. """ - if not include and not exclude: - return item + for k, v in merge_from.items(): + if ( + k in merge_to + and isinstance(merge_to[k], dict) + and isinstance(merge_from[k], dict) + ): + dict_deep_update(merge_to[k], merge_from[k]) + else: + merge_to[k] = v + +def include_fields(source: dict[str, Any], fields: set[str]) -> dict[str, Any]: # Build a shallow copy of included fields on an item, or a sub-tree of an item - def include_fields(source: dict[str, Any], fields: set[str] | None) -> dict[str, Any]: - if not fields: - return source - - clean_item: dict[str, Any] = {} - for key_path in fields or []: - key_path_parts = key_path.split(".") - key_root = key_path_parts[0] - if key_root in source: - if isinstance(source[key_root], dict) and len(key_path_parts) > 1: - # The root of this key path on the item is a dict, and the - # key path indicates a sub-key to be included. Walk the dict - # from the root key and get the full nested value to include. - value = include_fields( - source[key_root], fields={".".join(key_path_parts[1:])} - ) - - if isinstance(clean_item.get(key_root), dict): - # A previously specified key and sub-keys may have been included - # already, so do a deep merge update if the root key already exists. - dict_deep_update(clean_item[key_root], value) - else: - # The root key does not exist, so add it. Fields - # extension only allows nested referencing on dicts, so - # this won't overwrite anything. - clean_item[key_root] = value + if not fields: + return source + + clean_item: dict[str, Any] = {} + for key_path in fields: + key_path_parts = key_path.split(".") + key_root = key_path_parts[0] + if key_root in source: + if isinstance(source[key_root], dict) and len(key_path_parts) > 1: + # The root of this key path on the item is a dict, and the + # key path indicates a sub-key to be included. Walk the dict + # from the root key and get the full nested value to include. + value = include_fields( + source[key_root], fields={".".join(key_path_parts[1:])} + ) + + if isinstance(clean_item.get(key_root), dict): + # A previously specified key and sub-keys may have been included + # already, so do a deep merge update if the root key already exists. + dict_deep_update(clean_item[key_root], value) else: - # The item value to include is not a dict, or, it is a dict but the - # key path is for the whole value, not a sub-key. Include the entire - # value in the cleaned item. - clean_item[key_root] = source[key_root] + # The root key does not exist, so add it. Fields + # extension only allows nested referencing on dicts, so + # this won't overwrite anything. + clean_item[key_root] = value else: - # The key, or root key of a multi-part key, is not present in the item, - # so it is ignored - pass + # The item value to include is not a dict, or, it is a dict but the + # key path is for the whole value, not a sub-key. Include the entire + # value in the cleaned item. + clean_item[key_root] = source[key_root] + else: + # The key, or root key of a multi-part key, is not present in the item, + # so it is ignored + pass - return clean_item + return clean_item + +def exclude_fields(source: dict[str, Any], fields: set[str]) -> None: # For an item built up for included fields, remove excluded fields. This # modifies `source` in place. - def exclude_fields(source: dict[str, Any], fields: set[str] | None) -> None: - for key_path in fields or []: - key_path_part = key_path.split(".") - key_root = key_path_part[0] - if key_root in source: - if isinstance(source[key_root], dict) and len(key_path_part) > 1: - # Walk the nested path of this key to remove the leaf-key - exclude_fields(source[key_root], fields={".".join(key_path_part[1:])}) - # If, after removing the leaf-key, the root is now an empty - # dict, remove it entirely - if not source[key_root]: - del source[key_root] - else: - # The key's value is not a dict, or there is no sub-key to remove. The - # entire key can be removed from the source. - source.pop(key_root, None) + for key_path in fields: + key_path_part = key_path.split(".") + key_root = key_path_part[0] + if key_root in source: + if isinstance(source[key_root], dict) and len(key_path_part) > 1: + # Walk the nested path of this key to remove the leaf-key + exclude_fields(source[key_root], fields={".".join(key_path_part[1:])}) + # If, after removing the leaf-key, the root is now an empty + # dict, remove it entirely + if not source[key_root]: + del source[key_root] else: - # The key to remove does not exist on the source, so it is ignored - pass + # The key's value is not a dict, or there is no sub-key to remove. The + # entire key can be removed from the source. + source.pop(key_root, None) + else: + # The key to remove does not exist on the source, so it is ignored + pass + + +def filter_fields( # noqa: C901 + item: Item, + include: set[str], + exclude: set[str], +) -> Item: + """Preserve and remove fields as indicated by the fields extension include/exclude sets. + + Returns a shallow copy of the Item with the fields filtered. + + This will not perform a deep copy; values of the original item will be referenced + in the return item. + """ + if not include and not exclude: + return item clean_item = include_fields(dict(item), include) @@ -91,20 +133,3 @@ def exclude_fields(source: dict[str, Any], fields: set[str] | None) -> None: exclude_fields(clean_item, exclude) return cast(Item, clean_item) - - -def dict_deep_update(merge_to: dict[str, Any], merge_from: dict[str, Any]) -> None: - """Perform a deep update of two dicts. - - merge_to is updated in-place with the values from merge_from. - merge_from values take precedence over existing values in merge_to. - """ - for k, v in merge_from.items(): - if ( - k in merge_to - and isinstance(merge_to[k], dict) - and isinstance(merge_from[k], dict) - ): - dict_deep_update(merge_to[k], merge_from[k]) - else: - merge_to[k] = v diff --git a/tests/resources/test_item.py b/tests/resources/test_item.py index 1e9f8957..a95685f8 100644 --- a/tests/resources/test_item.py +++ b/tests/resources/test_item.py @@ -1228,7 +1228,9 @@ async def test_field_extension_exclude_and_include( resp = await app_client.post("/search", json=body) resp_json = resp.json() - assert "properties" not in resp_json["features"][0] + assert "properties" in resp_json["features"][0] + assert "eo:cloud_cover" in resp_json["features"][0]["properties"].keys() + assert len(resp_json["features"][0]["properties"].keys()) == 1 async def test_field_extension_exclude_default_includes( @@ -1310,7 +1312,28 @@ async def test_field_extension_exclude_deeply_nested_included_subkeys( resp_assets = resp_json["features"][0]["assets"] assert "type" in resp_assets["ANG"] - assert "href" not in resp_assets["ANG"] + assert "href" in resp_assets["ANG"] + + +async def test_field_extension_exclude_root_of_included_subkeys( + app_client, load_test_item, load_test_collection +): + """Test that a root key of included nested object is not excluded""" + body = { + "fields": { + "include": ["assets.ANG.type"], + "exclude": ["assets.ANG"], + } + } + + resp = await app_client.post("/search", json=body) + assert resp.status_code == 200 + resp_json = resp.json() + + resp_assets = resp_json["features"][0]["assets"] + assert "ANG" in resp_assets + assert "type" in resp_assets["ANG"] + assert len(resp_assets["ANG"].keys()) == 1 async def test_field_extension_exclude_links( diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..03318b55 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,87 @@ +import pytest + +import stac_fastapi.pgstac.utils as utils + + +def test_clean_exclude(): + include = {"a", "b", "c.c1"} + exclude = {"a", "c", "d"} + res = utils.clean_exclude(include, exclude) + exp = {"d"} + assert res == exp + + +def test_dict_deep_update(): + dict_0 = {"a": 0, "b": 0, "c": {"c1": 0, "c2": 0}, "d": {"d1": 0}, "e": 0, "f": 0} + dict_1 = {"b": 1, "c": {"c2": 1, "c3": 1}, "d": 1, "e": {"e1": 1}, "f": 1} + utils.dict_deep_update(dict_0, dict_1) + exp = { + "a": 0, + "b": 1, + "c": {"c1": 0, "c2": 1, "c3": 1}, + "d": 1, + "e": {"e1": 1}, + "f": 1, + } + assert dict_0 == exp + + +def test_include_fields_no_fields(): + source = {"a": 0} + res = utils.include_fields(source, fields=set()) + assert res == source + + +def test_include_fields(): + source = {"a": 0, "b": 0, "c": {"c1": 0, "c2": 0, "c3": 0}, "d": {"d1": 0}, "e": 0} + fields = {"a", "c.c1", "c.c2", "d", "f"} + res = utils.include_fields(source, fields=fields) + exp = {"a": 0, "c": {"c1": 0, "c2": 0}, "d": {"d1": 0}} + assert res == exp + + +def test_exclude_fields(): + source = { + "a": 0, + "b": 0, + "c": {"c1": 0, "c2": 0}, + "d": {"d1": 0}, + "e": {"e1": 0}, + "f": 0, + } + fields = {"a", "c.c1", "d.d1", "e"} + utils.exclude_fields(source, fields=fields) + exp = {"b": 0, "c": {"c2": 0}, "f": 0} + assert source == exp + + +def test_filter_fields_no_included_properties(): + item = utils.Item( + id="test_id", + collection="test_collection", + properties={"prop_1": 0, "prop_2": 0}, + ) + res = utils.filter_fields(item, include={"missing_field"}, exclude=set()) + exp = utils.Item(id="test_id", collection="test_collection") + assert res == exp + + +@pytest.mark.parametrize( + "include, exclude, exp", + [ + ({"field_a", "field_b"}, {"field_a"}, {"field_b": "b"}), + ({"field_a"}, {"field_a"}, {}), + ({"properties"}, {"properties.prop_1"}, {"properties": {"prop_2": 0}}), + ({"properties.prop_1"}, {"properties"}, {}), + ], +) +def test_filter_fields(include, exclude, exp): + source = utils.Item( + id="test_id", + collection="test_collection", + properties={"prop_1": 0, "prop_2": 0}, + field_a="a", + field_b="b", + ) + res = utils.filter_fields(source, include=include, exclude=exclude) + assert res == utils.Item(**exp)