Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from spatialdata._core.query.relational_query import _locate_value
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale
from xarray import DataArray, DataTree
Expand Down Expand Up @@ -794,6 +794,14 @@ def _get_colors_for_categorical_obs(
return palette[:len_cat] # type: ignore[return-value]


def _format_element_names(element_name: list[str] | str | None) -> str:
if element_name is None:
return "the requested element"
if isinstance(element_name, str):
return f"'{element_name}'"
return ", ".join(f"'{name}'" for name in element_name)


def _format_element_name(element_name: list[str] | str | None) -> str:
if isinstance(element_name, str):
return element_name
Expand All @@ -802,6 +810,86 @@ def _format_element_name(element_name: list[str] | str | None) -> str:
return "<unknown>"


def _preview_values(values: Sequence[Any], limit: int = 5) -> str:
values = list(values)
preview = ", ".join(map(str, values[:limit]))
if len(values) > limit:
preview += ", ..."
return preview


def _ensure_one_to_one_mapping(
sdata: SpatialData,
element: SpatialElement | None,
element_name: list[str] | str | None,
table_name: str | None,
) -> None:
if table_name is None or element_name is None:
return

table = sdata.get(table_name, None)
if table is None:
return

_validate_table_instance_uniqueness(table, element_name, table_name)
_validate_shape_index_uniqueness(element, element_name, table_name)


def _validate_shape_index_uniqueness(
element: SpatialElement | None,
element_name: list[str] | str | None,
table_name: str,
) -> None:
if not isinstance(element, GeoDataFrame):
return

duplicates = element.index[element.index.duplicated(keep=False)]
if duplicates.empty:
return

element_label = _format_element_names(element_name)
preview = _preview_values(pd.Index(duplicates).unique())
raise ValueError(
f"{element_label} contains duplicate index values ({preview}) while table '{table_name}' "
"requires a one-to-one mapping between shapes and annotations. "
"Please ensure each spatial element has a unique index."
)


def _validate_table_instance_uniqueness(
table: AnnData,
element_name: list[str] | str | None,
table_name: str,
) -> None:
try:
_, region_key, instance_key = get_table_keys(table)
except (AttributeError, KeyError, ValueError):
return

if instance_key is None or instance_key not in table.obs.columns:
return

obs = table.obs
if region_key is not None and region_key in obs.columns and element_name is not None:
element_names = [element_name] if isinstance(element_name, str) else list(element_name)
obs = obs[obs[region_key].isin(element_names)]

if obs.empty:
return

duplicates_mask = obs[instance_key].duplicated(keep=False)
if not duplicates_mask.any():
return

element_label = _format_element_names(element_name)
preview = _preview_values(obs.loc[duplicates_mask, instance_key].astype(str).unique())
raise ValueError(
f"Table '{table_name}' contains duplicate '{instance_key}' values for {element_label}: {preview}. "
"Each observation must annotate a single spatial element. Please deduplicate the table or subset it "
"before plotting."
)


def _infer_color_data_kind(
series: pd.Series,
value_to_plot: str,
Expand Down Expand Up @@ -885,6 +973,13 @@ def _set_color_source_vec(
)

if len(origins) == 1 and value_to_plot is not None:
if table_name is not None:
_ensure_one_to_one_mapping(
sdata=sdata,
element=element,
element_name=element_name,
table_name=table_name,
)
color_source_vector = get_values(
value_key=value_to_plot,
sdata=sdata,
Expand Down
49 changes: 49 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,55 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
norm = Normalize(vmin=0, vmax=5, clip=True)
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show()

def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialData):
element = "blobs_polygons"
shapes = sdata_blobs.shapes[element].copy()
n_shapes = len(shapes)
rng = get_standard_RNG()
adata = AnnData(rng.normal(size=(n_shapes, 3)))
adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes)
adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)]
adata.obs["region"] = pd.Categorical([element] * n_shapes)
table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id")
sdata_blobs["table"] = table
instance_key = table.uns["spatialdata_attrs"]["instance_key"]
shapes.index = table.obs[instance_key].tolist()
duplicated_index = shapes.index.to_list()
duplicated_index[1] = duplicated_index[0]
shapes.index = duplicated_index
sdata_blobs.shapes[element] = shapes

with pytest.raises(ValueError, match="duplicate index values"):
sdata_blobs.pl.render_shapes(
element=element,
color="annotation",
table_name="table",
).pl.show()

def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData):
element = "blobs_polygons"
shapes = sdata_blobs.shapes[element]
n_shapes = len(shapes)
rng = get_standard_RNG()
shape_ids = [f"shape_{i}" for i in range(n_shapes)]
shapes.index = shape_ids
sdata_blobs.shapes[element] = shapes
adata = AnnData(rng.normal(size=(n_shapes, 3)))
adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes)
adata.obs["instance_id"] = shape_ids
adata.obs["region"] = pd.Categorical([element] * n_shapes)
table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id")
instance_key = table.uns["spatialdata_attrs"]["instance_key"]
table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key]
sdata_blobs["table"] = table

with pytest.raises(ValueError, match="duplicate 'instance"):
sdata_blobs.pl.render_shapes(
element=element,
color="annotation",
table_name="table",
).pl.show()

def test_render_shapes_raises_when_color_key_missing(self, sdata_blobs_shapes_annotated: SpatialData):
missing_col = "__non_existent_column__"
with pytest.raises(KeyError, match=f"Unable to locate color key '{missing_col}'"):
Expand Down