Skip to content

Commit 2e33801

Browse files
committed
fix: clear error on disjoint instance IDs in render_shapes/labels/points
1 parent 7e6a607 commit 2e33801

4 files changed

Lines changed: 110 additions & 3 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,23 @@ def _add_legend_and_colorbar(
311311
)
312312

313313

314+
def _check_instance_ids_overlap(
315+
sdata: sd.SpatialData,
316+
table_name: str,
317+
element_name: str,
318+
element_index: abc.Iterable[Any],
319+
) -> None:
320+
"""Raise a clear error when a table annotates an element but no instance IDs overlap (#603)."""
321+
_, region_key, instance_key = get_table_keys(sdata[table_name])
322+
annotating = sdata[table_name].obs[sdata[table_name].obs[region_key].isin([element_name])]
323+
if len(annotating) > 0 and set(annotating[instance_key]).isdisjoint(set(element_index)):
324+
raise ValueError(
325+
f"No instance IDs overlap between table '{table_name}' (instance_key='{instance_key}') "
326+
f"and element '{element_name}'. Check that the table's '{instance_key}' column matches the "
327+
f"element's index."
328+
)
329+
330+
314331
def _render_shapes(
315332
sdata: sd.SpatialData,
316333
render_params: ShapesRenderParams,
@@ -336,6 +353,9 @@ def _render_shapes(
336353
table = None
337354
shapes = sdata_filt[element]
338355
else:
356+
# check before mutating obs.index.name below so a failure leaves no half-restored state
357+
_check_instance_ids_overlap(sdata, table_name, element, sdata_filt[element].index)
358+
339359
# Workaround for upstream spatialdata bug (scverse/spatialdata#1099):
340360
# join_spatialelement_table calls table.obs.reset_index() which fails when
341361
# the obs index name matches an existing column (e.g. "EntityID" in Merfish data).
@@ -742,6 +762,9 @@ def _render_points(
742762

743763
added_color_from_table = False
744764
if col_for_color is not None and col_for_color not in points.columns:
765+
if table_name is not None:
766+
# guard against disjoint instance IDs (#603) for a clearer error than KeyError: None
767+
_check_instance_ids_overlap(sdata, table_name, element, points.index)
745768
color_values = get_values(
746769
value_key=col_for_color,
747770
sdata=sdata_filt,
@@ -1651,6 +1674,7 @@ def _render_labels(
16511674
instance_id = np.unique(label)
16521675
table = None
16531676
else:
1677+
_check_instance_ids_overlap(sdata, table_name, element, np.unique(label.values))
16541678
_, region_key, instance_key = get_table_keys(sdata[table_name])
16551679
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
16561680

tests/pl/test_render_labels.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,3 +521,32 @@ def test_render_labels_rejects_background_instance_id_in_table():
521521
sdata.pl.render_labels("lbl", color="score", table_name="t").pl.show(ax=ax)
522522
finally:
523523
plt.close(fig)
524+
525+
526+
def test_render_labels_disjoint_instance_ids_clear_error():
527+
# regression test for #603: disjoint instance_id values must raise a clear ValueError
528+
arr = np.zeros((20, 20), dtype=np.int32)
529+
arr[3:8, 3:8] = 1
530+
arr[12:17, 12:17] = 2
531+
obs = pd.DataFrame(
532+
{
533+
"instance_id": [99, 100], # label has IDs 1, 2 (no overlap)
534+
"region": pd.Categorical(["lbl"] * 2),
535+
"cat": pd.Categorical(["A", "B"]),
536+
}
537+
)
538+
obs.index = obs.index.astype(str)
539+
table = TableModel.parse(
540+
AnnData(X=np.zeros((2, 1)), obs=obs),
541+
region=["lbl"],
542+
region_key="region",
543+
instance_key="instance_id",
544+
)
545+
sdata = SpatialData(labels={"lbl": Labels2DModel.parse(arr, dims=["y", "x"])}, tables={"t": table})
546+
547+
fig, ax = plt.subplots()
548+
try:
549+
with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 'lbl'"):
550+
sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
551+
finally:
552+
plt.close(fig)

tests/pl/test_render_points.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,3 +1004,30 @@ def test_no_table_fallback_warning_for_element_column(caplog):
10041004
with logger_no_warns(caplog, logger, match="fallback for color mapping"):
10051005
sdata.pl.render_points("pts", color="cell_type").pl.show()
10061006
plt.close("all")
1007+
1008+
1009+
def test_render_points_disjoint_instance_ids_clear_error():
1010+
# regression test for #603: disjoint instance_id values must raise a clear ValueError
1011+
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
1012+
obs = pd.DataFrame(
1013+
{
1014+
"instance_id": [99, 100, 101], # points index is 0, 1, 2 (no overlap)
1015+
"region": pd.Categorical(["pts"] * 3),
1016+
"cat": pd.Categorical(["A", "B", "C"]),
1017+
}
1018+
)
1019+
obs.index = obs.index.astype(str)
1020+
table = TableModel.parse(
1021+
AnnData(X=np.zeros((3, 1)), obs=obs),
1022+
region=["pts"],
1023+
region_key="region",
1024+
instance_key="instance_id",
1025+
)
1026+
sdata = SpatialData(points={"pts": points}, tables={"t": table})
1027+
1028+
fig, ax = plt.subplots()
1029+
try:
1030+
with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 'pts'"):
1031+
sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax)
1032+
finally:
1033+
plt.close(fig)

tests/pl/test_render_shapes.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,9 +1076,7 @@ def test_gene_symbols_missing_column_raises_auto_detect(sdata_blobs: SpatialData
10761076
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
10771077
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
10781078
with pytest.raises(KeyError, match="`gene_symbols=`"):
1079-
sdata_blobs.pl.render_shapes(
1080-
"blobs_circles", color="GeneA", gene_symbols="WRONGCOL"
1081-
).pl.show()
1079+
sdata_blobs.pl.render_shapes("blobs_circles", color="GeneA", gene_symbols="WRONGCOL").pl.show()
10821080

10831081

10841082
def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
@@ -1351,6 +1349,35 @@ def test_render_shapes_color_with_conflicting_index_name():
13511349
sdata.pl.render_shapes("shapes", color="cell_type", table_name="table").pl.show()
13521350

13531351

1352+
def test_render_shapes_disjoint_instance_ids_clear_error():
1353+
# regression test for #603: disjoint instance_id values must raise a clear ValueError
1354+
shapes = ShapesModel.parse(
1355+
gpd.GeoDataFrame({"geometry": [Point(5, 5), Point(15, 5), Point(25, 5)], "radius": [2.0] * 3})
1356+
)
1357+
obs = pd.DataFrame(
1358+
{
1359+
"instance_id": [99, 100, 101], # element has IDs 0, 1, 2 (no overlap)
1360+
"region": pd.Categorical(["s"] * 3),
1361+
"cat": pd.Categorical(["A", "B", "C"]),
1362+
}
1363+
)
1364+
obs.index = obs.index.astype(str)
1365+
table = TableModel.parse(
1366+
AnnData(X=np.zeros((3, 1)), obs=obs),
1367+
region=["s"],
1368+
region_key="region",
1369+
instance_key="instance_id",
1370+
)
1371+
sdata = SpatialData(shapes={"s": shapes}, tables={"t": table})
1372+
1373+
fig, ax = plt.subplots()
1374+
try:
1375+
with pytest.raises(ValueError, match=r"No instance IDs overlap.*table 't'.*element 's'"):
1376+
sdata.pl.render_shapes("s", color="cat", table_name="t").pl.show(ax=ax)
1377+
finally:
1378+
plt.close(fig)
1379+
1380+
13541381
def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
13551382
"""Datashader colorbar range must not exceed the actual data range for shapes.
13561383

0 commit comments

Comments
 (0)