Skip to content

Commit 95fc089

Browse files
committed
fix: clear error on disjoint instance IDs in render_shapes/labels/points
Closes #603. Signed-off-by: SAY-5 <say.apm35@gmail.com>
1 parent 53abe71 commit 95fc089

4 files changed

Lines changed: 123 additions & 0 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,27 @@ def _add_legend_and_colorbar(
311311
)
312312

313313

def _check_instance_ids_overlap(
    sdata: sd.SpatialData,
    table_name: str,
    element_name: str,
    element_index: abc.Iterable[Any],
) -> None:
    """Raise a clear error when a table annotates an element but no instance IDs overlap (#603).

    Without this guard the downstream join/lookup crashes with an opaque ``KeyError: None`` from
    deep inside pandas, giving no hint that the ``instance_key`` value sets are disjoint.

    Parameters
    ----------
    sdata
        Container holding the table; only ``sdata[table_name]`` is read.
    table_name
        Key of the table whose ``obs`` is inspected.
    element_name
        Name of the element. Only ``obs`` rows whose region-key column equals this name are
        considered.
    element_index
        Instance IDs present in the element. Iterated at most once and never materialised.

    Raises
    ------
    ValueError
        If at least one ``obs`` row annotates ``element_name`` and none of the instance-key
        values of those rows occurs in ``element_index``.
    """
    table = sdata[table_name]
    _, region_key, instance_key = get_table_keys(table)
    obs = table.obs
    annotating = obs[obs[region_key].isin([element_name])]

    # A table that does not annotate this element at all is not an error here.
    if len(annotating) == 0:
        return

    # ``isdisjoint`` accepts any iterable and stops at the first shared ID, so the (possibly
    # very large) element index does not have to be copied into a second set.
    table_ids = set(annotating[instance_key])
    if table_ids.isdisjoint(element_index):
        raise ValueError(
            f"No instance IDs overlap between table '{table_name}' (instance_key='{instance_key}') "
            f"and element '{element_name}'. Check that the table's '{instance_key}' column matches the "
            f"element's index."
        )
333+
334+
314335
def _render_shapes(
315336
sdata: sd.SpatialData,
316337
render_params: ShapesRenderParams,
@@ -336,6 +357,10 @@ def _render_shapes(
336357
table = None
337358
shapes = sdata_filt[element]
338359
else:
360+
# Guard against a disjoint instance_key (#603) *before* mutating obs.index.name below,
361+
# so a failure here can never leave the index name in a half-restored state.
362+
_check_instance_ids_overlap(sdata, table_name, element, sdata[element].index)
363+
339364
# Workaround for upstream spatialdata bug (scverse/spatialdata#1099):
340365
# join_spatialelement_table calls table.obs.reset_index() which fails when
341366
# the obs index name matches an existing column (e.g. "EntityID" in Merfish data).
@@ -742,6 +767,10 @@ def _render_points(
742767

743768
added_color_from_table = False
744769
if col_for_color is not None and col_for_color not in points.columns:
770+
if table_name is not None:
771+
# Guard against disjoint instance IDs (#603): without this the table lookup below
772+
# crashes with an opaque `KeyError: None` instead of explaining the mismatch.
773+
_check_instance_ids_overlap(sdata, table_name, element, points.index)
745774
color_values = get_values(
746775
value_key=col_for_color,
747776
sdata=sdata_filt,
@@ -1651,6 +1680,7 @@ def _render_labels(
16511680
instance_id = np.unique(label)
16521681
table = None
16531682
else:
1683+
_check_instance_ids_overlap(sdata, table_name, element, np.unique(label.values))
16541684
_, region_key, instance_key = get_table_keys(sdata[table_name])
16551685
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
16561686

tests/pl/test_render_labels.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,34 @@ def test_render_labels_rejects_float_dtype(dtype):
490490
sdata.pl.render_labels("lbl").pl.show(ax=ax)
491491
finally:
492492
plt.close(fig)
493+
494+
def test_render_labels_disjoint_instance_ids_clear_error():
    """Regression test for #603: disjoint instance IDs must raise a clear ``ValueError``.

    The table annotates the label element (the region key matches) but none of its
    ``instance_id`` values occur among the label IDs. This used to surface as a bare
    ``KeyError: None`` from deep inside spatialdata.
    """
    # Label image containing the instance IDs 1 and 2 (0 is left as background).
    label_image = np.zeros((20, 20), dtype=np.int32)
    label_image[3:8, 3:8] = 1
    label_image[12:17, 12:17] = 2

    obs = pd.DataFrame(
        {
            "instance_id": [99, 100],  # label has IDs 1, 2 -- no overlap
            "region": pd.Categorical(["lbl"] * 2),
            "cat": pd.Categorical(["A", "B"]),
        }
    )
    obs.index = obs.index.astype(str)

    adata = AnnData(X=np.zeros((2, 1)), obs=obs)
    table = TableModel.parse(
        adata,
        region=["lbl"],
        region_key="region",
        instance_key="instance_id",
    )
    labels = Labels2DModel.parse(label_image, dims=["y", "x"])
    sdata = SpatialData(labels={"lbl": labels}, tables={"t": table})

    expected_message = r"No instance IDs overlap.*table 't'.*element 'lbl'"
    fig, ax = plt.subplots()
    try:
        with pytest.raises(ValueError, match=expected_message):
            sdata.pl.render_labels("lbl", color="cat", table_name="t").pl.show(ax=ax)
    finally:
        # Always release the figure, whether or not the expected error was raised.
        plt.close(fig)

tests/pl/test_render_points.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,3 +1004,32 @@ def test_no_table_fallback_warning_for_element_column(caplog):
10041004
with logger_no_warns(caplog, logger, match="fallback for color mapping"):
10051005
sdata.pl.render_points("pts", color="cell_type").pl.show()
10061006
plt.close("all")
1007+
1008+
def test_render_points_disjoint_instance_ids_clear_error():
    """Regression test for #603: disjoint instance IDs must raise a clear ``ValueError``.

    The table annotates the points element (the region key matches) but none of its
    ``instance_id`` values occur in the points index. The table lookup used to crash with a
    bare ``KeyError: None`` from deep inside spatialdata.
    """
    coordinates = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]})
    points = PointsModel.parse(coordinates)

    obs = pd.DataFrame(
        {
            "instance_id": [99, 100, 101],  # points index is 0, 1, 2 -- no overlap
            "region": pd.Categorical(["pts"] * 3),
            "cat": pd.Categorical(["A", "B", "C"]),
        }
    )
    obs.index = obs.index.astype(str)

    adata = AnnData(X=np.zeros((3, 1)), obs=obs)
    table = TableModel.parse(
        adata,
        region=["pts"],
        region_key="region",
        instance_key="instance_id",
    )
    sdata = SpatialData(points={"pts": points}, tables={"t": table})

    expected_message = r"No instance IDs overlap.*table 't'.*element 'pts'"
    fig, ax = plt.subplots()
    try:
        with pytest.raises(ValueError, match=expected_message):
            sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax)
    finally:
        # Always release the figure, whether or not the expected error was raised.
        plt.close(fig)

tests/pl/test_render_shapes.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,39 @@ def test_render_shapes_color_with_conflicting_index_name():
13231323
sdata.pl.render_shapes("shapes", color="cell_type", table_name="table").pl.show()
13241324

13251325

def test_render_shapes_disjoint_instance_ids_clear_error():
    """Regression test for #603: disjoint instance IDs must raise a clear ``ValueError``.

    The table annotates the shapes element (the region key matches) but none of its
    ``instance_id`` values overlap with the element's IDs. This used to crash with a bare
    ``KeyError: None`` from deep inside spatialdata's join path.
    """
    from shapely.geometry import Point

    circles = gpd.GeoDataFrame(
        {
            "geometry": [Point(5, 5), Point(15, 5), Point(25, 5)],
            "radius": [2.0] * 3,
        }
    )
    shapes = ShapesModel.parse(circles)

    obs = pd.DataFrame(
        {
            "instance_id": [99, 100, 101],  # element has IDs 0, 1, 2 -- no overlap
            "region": pd.Categorical(["s"] * 3),
            "cat": pd.Categorical(["A", "B", "C"]),
        }
    )
    obs.index = obs.index.astype(str)

    adata = AnnData(X=np.zeros((3, 1)), obs=obs)
    table = TableModel.parse(
        adata,
        region=["s"],
        region_key="region",
        instance_key="instance_id",
    )
    sdata = SpatialData(shapes={"s": shapes}, tables={"t": table})

    expected_message = r"No instance IDs overlap.*table 't'.*element 's'"
    fig, ax = plt.subplots()
    try:
        with pytest.raises(ValueError, match=expected_message):
            sdata.pl.render_shapes("s", color="cat", table_name="t").pl.show(ax=ax)
    finally:
        # Always release the figure, whether or not the expected error was raised.
        plt.close(fig)
1357+
1358+
13261359
def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
13271360
"""Datashader colorbar range must not exceed the actual data range for shapes.
13281361

0 commit comments

Comments
 (0)