Skip to content

Commit 927f689

Browse files
authored
Safeguard against invalid user input in render_shapes (#512)
1 parent 64aeb73 commit 927f689

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

src/spatialdata_plot/pl/utils.py

Lines changed: 96 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@
6565
)
6666
from spatialdata._core.query.relational_query import _locate_value
6767
from spatialdata._types import ArrayLike
68-
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement
68+
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys
6969
from spatialdata.transformations.operations import get_transformation
7070
from spatialdata.transformations.transformations import Scale
7171
from xarray import DataArray, DataTree
@@ -794,6 +794,14 @@ def _get_colors_for_categorical_obs(
794794
return palette[:len_cat] # type: ignore[return-value]
795795

796796

797+
def _format_element_names(element_name: list[str] | str | None) -> str:
798+
if element_name is None:
799+
return "the requested element"
800+
if isinstance(element_name, str):
801+
return f"'{element_name}'"
802+
return ", ".join(f"'{name}'" for name in element_name)
803+
804+
797805
def _format_element_name(element_name: list[str] | str | None) -> str:
798806
if isinstance(element_name, str):
799807
return element_name
@@ -802,6 +810,86 @@ def _format_element_name(element_name: list[str] | str | None) -> str:
802810
return "<unknown>"
803811

804812

813+
def _preview_values(values: Sequence[Any], limit: int = 5) -> str:
814+
values = list(values)
815+
preview = ", ".join(map(str, values[:limit]))
816+
if len(values) > limit:
817+
preview += ", ..."
818+
return preview
819+
820+
821+
def _ensure_one_to_one_mapping(
822+
sdata: SpatialData,
823+
element: SpatialElement | None,
824+
element_name: list[str] | str | None,
825+
table_name: str | None,
826+
) -> None:
827+
if table_name is None or element_name is None:
828+
return
829+
830+
table = sdata.get(table_name, None)
831+
if table is None:
832+
return
833+
834+
_validate_table_instance_uniqueness(table, element_name, table_name)
835+
_validate_shape_index_uniqueness(element, element_name, table_name)
836+
837+
838+
def _validate_shape_index_uniqueness(
    element: SpatialElement | None,
    element_name: list[str] | str | None,
    table_name: str,
) -> None:
    """Raise if a shapes element carries repeated index values.

    Only ``GeoDataFrame`` elements are inspected; anything else (including
    ``None``) passes silently.

    Parameters
    ----------
    element
        The spatial element whose index is checked.
    element_name
        Name(s) of the element, used only to build the error message.
    table_name
        Name of the annotating table, used only to build the error message.

    Raises
    ------
    ValueError
        If any index value occurs more than once.
    """
    if isinstance(element, GeoDataFrame):
        shape_index = element.index
        # keep=False marks every occurrence, not just the later repeats.
        repeated = shape_index[shape_index.duplicated(keep=False)]
        if not repeated.empty:
            label = _format_element_names(element_name)
            sample = _preview_values(pd.Index(repeated).unique())
            message = (
                f"{label} contains duplicate index values ({sample}) while table '{table_name}' "
                "requires a one-to-one mapping between shapes and annotations. "
                "Please ensure each spatial element has a unique index."
            )
            raise ValueError(message)
857+
858+
859+
def _validate_table_instance_uniqueness(
    table: AnnData,
    element_name: list[str] | str | None,
    table_name: str,
) -> None:
    """Raise if the table annotates the same instance more than once.

    The check is best effort: it is skipped when the table keys cannot be
    read, when the instance column is missing from ``table.obs``, or when no
    rows remain after restricting ``obs`` to ``element_name``.

    Parameters
    ----------
    table
        The annotating table; its ``obs`` frame is inspected.
    element_name
        Name(s) of the annotated element(s). When the region column exists,
        only rows whose region is one of these names are checked.
    table_name
        Name of the table, used only to build the error message.

    Raises
    ------
    ValueError
        If two or more of the considered rows point at the same instance.
    """
    try:
        _, region_key, instance_key = get_table_keys(table)
    except (AttributeError, KeyError, ValueError):
        # A table without readable key metadata cannot be validated here.
        return

    obs = table.obs
    if instance_key is None or instance_key not in obs.columns:
        return

    # Columns that together identify one annotated instance.
    identity_columns = [instance_key]
    if region_key is not None and region_key in obs.columns:
        # The same instance id may appear under different regions, so when
        # several elements are considered the pair (region, instance) must be
        # unique, not the instance id alone.
        identity_columns.insert(0, region_key)
        if element_name is not None:
            element_names = [element_name] if isinstance(element_name, str) else list(element_name)
            obs = obs[obs[region_key].isin(element_names)]

    if obs.empty:
        return

    # keep=False flags every row involved in a duplicate, not just the repeats.
    duplicates_mask = obs.duplicated(subset=identity_columns, keep=False)
    if not duplicates_mask.any():
        return

    element_label = _format_element_names(element_name)
    preview = _preview_values(obs.loc[duplicates_mask, instance_key].astype(str).unique())
    raise ValueError(
        f"Table '{table_name}' contains duplicate '{instance_key}' values for {element_label}: {preview}. "
        "Each observation must annotate a single spatial element. Please deduplicate the table or subset it "
        "before plotting."
    )
891+
892+
805893
def _infer_color_data_kind(
806894
series: pd.Series,
807895
value_to_plot: str,
@@ -885,6 +973,13 @@ def _set_color_source_vec(
885973
)
886974

887975
if len(origins) == 1 and value_to_plot is not None:
976+
if table_name is not None:
977+
_ensure_one_to_one_mapping(
978+
sdata=sdata,
979+
element=element,
980+
element_name=element_name,
981+
table_name=table_name,
982+
)
888983
color_source_vector = get_values(
889984
value_key=value_to_plot,
890985
sdata=sdata,

tests/pl/test_render_shapes.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,55 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
209209
norm = Normalize(vmin=0, vmax=5, clip=True)
210210
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show()
211211

212+
def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialData):
    """Rendering must fail when the shapes index contains a repeated value."""
    region = "blobs_polygons"
    polygons = sdata_blobs.shapes[region].copy()
    count = len(polygons)

    # Build a table with one uniquely identified row per shape.
    rng = get_standard_RNG()
    annotations = AnnData(rng.normal(size=(count, 3)))
    annotations.obs["annotation"] = rng.choice(["a", "b"], size=count)
    annotations.obs["instance_id"] = [f"id_{i}" for i in range(count)]
    annotations.obs["region"] = pd.Categorical([region] * count)
    table = TableModel.parse(adata=annotations, region=region, region_key="region", instance_key="instance_id")
    sdata_blobs["table"] = table

    # Reuse the table's ids as the shapes index, then repeat the first id so
    # that only the element side violates the one-to-one mapping.
    key = table.uns["spatialdata_attrs"]["instance_key"]
    ids = table.obs[key].tolist()
    ids[1] = ids[0]
    polygons.index = ids
    sdata_blobs.shapes[region] = polygons

    with pytest.raises(ValueError, match="duplicate index values"):
        sdata_blobs.pl.render_shapes(
            element=region,
            color="annotation",
            table_name="table",
        ).pl.show()
236+
237+
def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData):
    """Rendering must fail when two table rows point at the same shape."""
    element = "blobs_polygons"
    # Work on a copy, as the duplicate-shape-index test does, so the
    # GeoDataFrame held by the fixture is never mutated in place.
    shapes = sdata_blobs.shapes[element].copy()
    n_shapes = len(shapes)
    rng = get_standard_RNG()

    # Give the shapes a unique index so only the table side is invalid.
    shape_ids = [f"shape_{i}" for i in range(n_shapes)]
    shapes.index = shape_ids
    sdata_blobs.shapes[element] = shapes

    adata = AnnData(rng.normal(size=(n_shapes, 3)))
    adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes)
    adata.obs["instance_id"] = shape_ids
    adata.obs["region"] = pd.Categorical([element] * n_shapes)
    table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id")

    # Make the second observation annotate the same instance as the first.
    instance_key = table.uns["spatialdata_attrs"]["instance_key"]
    table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key]
    sdata_blobs["table"] = table

    with pytest.raises(ValueError, match="duplicate 'instance"):
        sdata_blobs.pl.render_shapes(
            element=element,
            color="annotation",
            table_name="table",
        ).pl.show()
260+
212261
def test_render_shapes_raises_when_color_key_missing(self, sdata_blobs_shapes_annotated: SpatialData):
213262
missing_col = "__non_existent_column__"
214263
with pytest.raises(KeyError, match=f"Unable to locate color key '{missing_col}'"):

0 commit comments

Comments
 (0)