Skip to content
Open
12 changes: 11 additions & 1 deletion stac_fastapi/pgstac/types/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""stac_fastapi.types.search module."""

from pydantic import ValidationInfo, field_validator
from pydantic import ValidationInfo, field_validator, model_validator
from stac_fastapi.types.search import BaseSearchPostRequest

from stac_fastapi.pgstac.utils import clean_exclude


class PgstacSearch(BaseSearchPostRequest):
"""Search model.
Expand All @@ -22,3 +24,11 @@ def validate_query_uses_cql(cls, v: str, info: ValidationInfo):
)

return v

@model_validator(mode="after")
def validate_fields(self):
    """Give the include set precedence by pruning conflicting excludes.

    When the model carries a ``fields`` attribute with both a non-empty
    ``include`` and a non-empty ``exclude``, the exclude set is replaced by
    the result of ``clean_exclude``. In every other case the model is
    returned untouched.
    """
    fields = getattr(self, "fields", None)
    # Nothing to reconcile unless both sets are present and non-empty.
    if not fields or not fields.include or not fields.exclude:
        return self

    fields.exclude = clean_exclude(fields.include, fields.exclude)
    return self
185 changes: 105 additions & 80 deletions stac_fastapi/pgstac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,81 +5,123 @@
from stac_fastapi.types.stac import Item


def filter_fields( # noqa: C901
item: Item,
include: set[str] | None = None,
exclude: set[str] | None = None,
) -> Item:
"""Preserve and remove fields as indicated by the fields extension include/exclude sets.
def clean_exclude(
    include: set[str],
    exclude: set[str],
) -> set[str]:
    """Clean the exclude set to ensure precedence of the include set.

    Cleaning includes:
    - Removing any fields from the exclude set that are also in the include set,
      since the include set takes precedence.
    - Removing any fields from the exclude set that are parent paths of fields in
      the include set, since including a sub-field of an excluded parent field
      should take precedence.

    Args:
        include: Dot-separated key paths that were requested for inclusion.
        exclude: Dot-separated key paths that were requested for exclusion.

    Returns:
        A new set holding the exclude entries that survive the cleaning.
        Neither input set is modified.
    """
    cleaned: set[str] = set()
    for field_excluded in exclude:
        if field_excluded in include:
            # Exact conflict: the include wins.
            continue
        # The trailing dot makes sure "a" is only treated as a parent of
        # "a.b", never of an unrelated sibling such as "ab".
        parent_prefix = field_excluded + "."
        if any(
            field_included.startswith(parent_prefix) for field_included in include
        ):
            continue
        cleaned.add(field_excluded)
    return cleaned

Returns a shallow copy of the Item with the fields filtered.

This will not perform a deep copy; values of the original item will be referenced
in the return item.
def dict_deep_update(merge_to: dict[str, Any], merge_from: dict[str, Any]) -> None:
    """Perform a deep update of two dicts.

    merge_to is updated in-place with the values from merge_from.
    merge_from values take precedence over existing values in merge_to.

    When a key holds a dict on both sides the two dicts are merged
    recursively; in every other case the value from merge_from replaces
    whatever merge_to held for that key.
    """
    for k, v in merge_from.items():
        existing = merge_to.get(k)
        if isinstance(existing, dict) and isinstance(v, dict):
            # Both sides are dicts: merge the sub-trees instead of replacing.
            dict_deep_update(existing, v)
        else:
            merge_to[k] = v


def include_fields(source: dict[str, Any], fields: set[str]) -> dict[str, Any]:
# Build a shallow copy of included fields on an item, or a sub-tree of an item
def include_fields(source: dict[str, Any], fields: set[str] | None) -> dict[str, Any]:
if not fields:
return source

clean_item: dict[str, Any] = {}
for key_path in fields or []:
key_path_parts = key_path.split(".")
key_root = key_path_parts[0]
if key_root in source:
if isinstance(source[key_root], dict) and len(key_path_parts) > 1:
# The root of this key path on the item is a dict, and the
# key path indicates a sub-key to be included. Walk the dict
# from the root key and get the full nested value to include.
value = include_fields(
source[key_root], fields={".".join(key_path_parts[1:])}
)

if isinstance(clean_item.get(key_root), dict):
# A previously specified key and sub-keys may have been included
# already, so do a deep merge update if the root key already exists.
dict_deep_update(clean_item[key_root], value)
else:
# The root key does not exist, so add it. Fields
# extension only allows nested referencing on dicts, so
# this won't overwrite anything.
clean_item[key_root] = value
if not fields:
return source

clean_item: dict[str, Any] = {}
for key_path in fields:
key_path_parts = key_path.split(".")
key_root = key_path_parts[0]
if key_root in source:
if isinstance(source[key_root], dict) and len(key_path_parts) > 1:
# The root of this key path on the item is a dict, and the
# key path indicates a sub-key to be included. Walk the dict
# from the root key and get the full nested value to include.
value = include_fields(
source[key_root], fields={".".join(key_path_parts[1:])}
)

if isinstance(clean_item.get(key_root), dict):
# A previously specified key and sub-keys may have been included
# already, so do a deep merge update if the root key already exists.
dict_deep_update(clean_item[key_root], value)
else:
# The item value to include is not a dict, or, it is a dict but the
# key path is for the whole value, not a sub-key. Include the entire
# value in the cleaned item.
clean_item[key_root] = source[key_root]
# The root key does not exist, so add it. Fields
# extension only allows nested referencing on dicts, so
# this won't overwrite anything.
clean_item[key_root] = value
else:
# The key, or root key of a multi-part key, is not present in the item,
# so it is ignored
pass
# The item value to include is not a dict, or, it is a dict but the
# key path is for the whole value, not a sub-key. Include the entire
# value in the cleaned item.
clean_item[key_root] = source[key_root]
else:
# The key, or root key of a multi-part key, is not present in the item,
# so it is ignored
pass

return clean_item
return clean_item


def exclude_fields(source: dict[str, Any], fields: set[str]) -> None:
# For an item built up for included fields, remove excluded fields. This
# modifies `source` in place.
def exclude_fields(source: dict[str, Any], fields: set[str] | None) -> None:
for key_path in fields or []:
key_path_part = key_path.split(".")
key_root = key_path_part[0]
if key_root in source:
if isinstance(source[key_root], dict) and len(key_path_part) > 1:
# Walk the nested path of this key to remove the leaf-key
exclude_fields(source[key_root], fields={".".join(key_path_part[1:])})
# If, after removing the leaf-key, the root is now an empty
# dict, remove it entirely
if not source[key_root]:
del source[key_root]
else:
# The key's value is not a dict, or there is no sub-key to remove. The
# entire key can be removed from the source.
source.pop(key_root, None)
for key_path in fields:
key_path_part = key_path.split(".")
key_root = key_path_part[0]
if key_root in source:
if isinstance(source[key_root], dict) and len(key_path_part) > 1:
# Walk the nested path of this key to remove the leaf-key
exclude_fields(source[key_root], fields={".".join(key_path_part[1:])})
# If, after removing the leaf-key, the root is now an empty
# dict, remove it entirely
if not source[key_root]:
del source[key_root]
else:
# The key to remove does not exist on the source, so it is ignored
pass
# The key's value is not a dict, or there is no sub-key to remove. The
# entire key can be removed from the source.
source.pop(key_root, None)
else:
# The key to remove does not exist on the source, so it is ignored
pass


def filter_fields( # noqa: C901
item: Item,
include: set[str],
exclude: set[str],
) -> Item:
"""Preserve and remove fields as indicated by the fields extension include/exclude sets.

Returns a shallow copy of the Item with the fields filtered.

This will not perform a deep copy; values of the original item will be referenced
in the return item.
"""
if not include and not exclude:
return item

clean_item = include_fields(dict(item), include)

Expand All @@ -91,20 +133,3 @@ def exclude_fields(source: dict[str, Any], fields: set[str] | None) -> None:
exclude_fields(clean_item, exclude)

return cast(Item, clean_item)


def dict_deep_update(merge_to: dict[str, Any], merge_from: dict[str, Any]) -> None:
    """Recursively merge ``merge_from`` into ``merge_to``.

    ``merge_to`` is modified in place. Wherever the two disagree, the value
    from ``merge_from`` wins; nested dicts present on both sides are merged
    key by key rather than replaced.
    """
    for key, incoming in merge_from.items():
        current = merge_to[key] if key in merge_to else None
        mergeable = isinstance(current, dict) and isinstance(incoming, dict)
        if not mergeable:
            # Plain overwrite: at least one side is not a dict.
            merge_to[key] = incoming
            continue
        dict_deep_update(current, incoming)
27 changes: 25 additions & 2 deletions tests/resources/test_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,9 @@ async def test_field_extension_exclude_and_include(

resp = await app_client.post("/search", json=body)
resp_json = resp.json()
assert "properties" not in resp_json["features"][0]
assert "properties" in resp_json["features"][0]
assert "eo:cloud_cover" in resp_json["features"][0]["properties"].keys()
assert len(resp_json["features"][0]["properties"].keys()) == 1


async def test_field_extension_exclude_default_includes(
Expand Down Expand Up @@ -1310,7 +1312,28 @@ async def test_field_extension_exclude_deeply_nested_included_subkeys(

resp_assets = resp_json["features"][0]["assets"]
assert "type" in resp_assets["ANG"]
assert "href" not in resp_assets["ANG"]
assert "href" in resp_assets["ANG"]


async def test_field_extension_exclude_root_of_included_subkeys(
    app_client, load_test_item, load_test_collection
):
    """Check that excluding the parent of an included nested key keeps that key."""
    fields = {
        "include": ["assets.ANG.type"],
        "exclude": ["assets.ANG"],
    }

    resp = await app_client.post("/search", json={"fields": fields})
    assert resp.status_code == 200

    first_feature = resp.json()["features"][0]
    resp_assets = first_feature["assets"]
    assert "ANG" in resp_assets

    # Only the explicitly included sub-key may remain on the asset.
    ang_asset = resp_assets["ANG"]
    assert "type" in ang_asset
    assert len(ang_asset.keys()) == 1


async def test_field_extension_exclude_links(
Expand Down
87 changes: 87 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest

import stac_fastapi.pgstac.utils as utils


def test_clean_exclude():
    """Exact matches and parents of included paths are dropped from the exclude set."""
    cleaned = utils.clean_exclude(
        {"a", "b", "c.c1"},
        {"a", "c", "d"},
    )
    assert cleaned == {"d"}


def test_dict_deep_update():
    """Nested dicts are merged; every other value is overwritten by the update."""
    target = {
        "a": 0,
        "b": 0,
        "c": {"c1": 0, "c2": 0},
        "d": {"d1": 0},
        "e": 0,
        "f": 0,
    }
    update = {"b": 1, "c": {"c2": 1, "c3": 1}, "d": 1, "e": {"e1": 1}, "f": 1}

    utils.dict_deep_update(target, update)

    assert target == {
        "a": 0,
        "b": 1,
        "c": {"c1": 0, "c2": 1, "c3": 1},
        "d": 1,
        "e": {"e1": 1},
        "f": 1,
    }


def test_include_fields_no_fields():
    """An empty field set leaves the source unchanged."""
    source = {"a": 0}
    assert utils.include_fields(source, fields=set()) == source


def test_include_fields():
    """Only requested keys and sub-keys survive; unknown keys are ignored."""
    source = {
        "a": 0,
        "b": 0,
        "c": {"c1": 0, "c2": 0, "c3": 0},
        "d": {"d1": 0},
        "e": 0,
    }
    requested = {"a", "c.c1", "c.c2", "d", "f"}

    included = utils.include_fields(source, fields=requested)

    assert included == {"a": 0, "c": {"c1": 0, "c2": 0}, "d": {"d1": 0}}


def test_exclude_fields():
    """Excluded keys are removed in place, and emptied parent dicts are dropped."""
    source = {
        "a": 0,
        "b": 0,
        "c": {"c1": 0, "c2": 0},
        "d": {"d1": 0},
        "e": {"e1": 0},
        "f": 0,
    }
    to_remove = {"a", "c.c1", "d.d1", "e"}

    utils.exclude_fields(source, fields=to_remove)

    assert source == {"b": 0, "c": {"c2": 0}, "f": 0}


def test_filter_fields_no_included_properties():
    """Including only a field the item does not have compares equal to the expected item."""
    item = utils.Item(
        id="test_id",
        collection="test_collection",
        properties={"prop_1": 0, "prop_2": 0},
    )
    expected = utils.Item(id="test_id", collection="test_collection")

    filtered = utils.filter_fields(item, include={"missing_field"}, exclude=set())

    assert filtered == expected


@pytest.mark.parametrize(
    "include, exclude, exp",
    [
        ({"field_a", "field_b"}, {"field_a"}, {"field_b": "b"}),
        ({"field_a"}, {"field_a"}, {}),
        ({"properties"}, {"properties.prop_1"}, {"properties": {"prop_2": 0}}),
        ({"properties.prop_1"}, {"properties"}, {}),
    ],
)
def test_filter_fields(include, exclude, exp):
    """Filter one fixed item with each include/exclude pair and compare to ``exp``."""
    item_kwargs = {
        "id": "test_id",
        "collection": "test_collection",
        "properties": {"prop_1": 0, "prop_2": 0},
        "field_a": "a",
        "field_b": "b",
    }
    source = utils.Item(**item_kwargs)

    filtered = utils.filter_fields(source, include=include, exclude=exclude)

    assert filtered == utils.Item(**exp)
Loading