Skip to content

Commit 35b5e20

Browse files
authored
Now raising an error if color or table doesn't exist (#515)
1 parent 17df374 commit 35b5e20

9 files changed

+346
-156
lines changed

src/spatialdata_plot/_logging.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# from https://github.com/scverse/spatialdata/blob/main/src/spatialdata/_logging.py
22

33
import logging
4+
import re
5+
from collections.abc import Iterator
6+
from contextlib import contextmanager
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
from _pytest.logging import LogCaptureFixture
411

512

613
def _setup_logger() -> "logging.Logger":
@@ -21,3 +28,44 @@ def _setup_logger() -> "logging.Logger":
2128

2229

2330
logger = _setup_logger()
31+
32+
33+
@contextmanager
34+
def logger_warns(
35+
caplog: "LogCaptureFixture",
36+
logger: logging.Logger,
37+
match: str | None = None,
38+
level: int = logging.WARNING,
39+
) -> Iterator[None]:
40+
"""
41+
Context manager similar to pytest.warns, but for logging.Logger.
42+
43+
Usage:
44+
with logger_warns(caplog, logger, match="Found 1 NaN"):
45+
call_code_that_logs()
46+
"""
47+
# Store initial record count to only check new records
48+
initial_record_count = len(caplog.records)
49+
50+
# Add caplog's handler directly to the logger to capture logs even if propagate=False
51+
handler = caplog.handler
52+
logger.addHandler(handler)
53+
original_level = logger.level
54+
logger.setLevel(level)
55+
56+
# Use caplog.at_level to ensure proper capture setup
57+
with caplog.at_level(level, logger=logger.name):
58+
try:
59+
yield
60+
finally:
61+
logger.removeHandler(handler)
62+
logger.setLevel(original_level)
63+
64+
# Only check records that were added during this context
65+
records = [r for r in caplog.records[initial_record_count:] if r.levelno >= level]
66+
67+
if match is not None:
68+
pattern = re.compile(match)
69+
if not any(pattern.search(r.getMessage()) for r in records):
70+
msgs = [r.getMessage() for r in records]
71+
raise AssertionError(f"Did not find log matching {match!r} in records: {msgs!r}")

src/spatialdata_plot/pl/render.py

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import warnings
43
from collections import abc
54
from copy import copy
65

@@ -49,7 +48,6 @@
4948
_get_extent_and_range_for_datashader_canvas,
5049
_get_linear_colormap,
5150
_hex_no_alpha,
52-
_is_coercable_to_float,
5351
_map_color_seg,
5452
_maybe_set_colors,
5553
_mpl_ax_contains_elements,
@@ -94,20 +92,7 @@ def _render_shapes(
9492
)
9593
sdata_filt[table_name] = table = joined_table
9694

97-
if (
98-
col_for_color is not None
99-
and table_name is not None
100-
and col_for_color in sdata_filt[table_name].obs.columns
101-
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
102-
and not _is_coercable_to_float(color_col)
103-
):
104-
warnings.warn(
105-
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical plotting. "
106-
f"Consider converting before plotting.",
107-
UserWarning,
108-
stacklevel=2,
109-
)
110-
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
95+
shapes = sdata_filt[element]
11196

11297
# get color vector (categorical or continuous)
11398
color_source_vector, color_vector, _ = _set_color_source_vec(
@@ -121,6 +106,7 @@ def _render_shapes(
121106
cmap_params=render_params.cmap_params,
122107
table_name=table_name,
123108
table_layer=table_layer,
109+
coordinate_system=coordinate_system,
124110
)
125111

126112
values_are_categorical = color_source_vector is not None
@@ -144,12 +130,25 @@ def _render_shapes(
144130

145131
# continuous case: leave NaNs as NaNs; utils maps them to na_color during draw
146132
if color_source_vector is None and not values_are_categorical:
147-
color_vector = np.asarray(color_vector, dtype=float)
148-
if np.isnan(color_vector).any():
149-
nan_count = int(np.isnan(color_vector).sum())
150-
msg = f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
151-
warnings.warn(msg, UserWarning, stacklevel=2)
152-
logger.warning(msg)
133+
_series = color_vector if isinstance(color_vector, pd.Series) else pd.Series(color_vector)
134+
135+
try:
136+
color_vector = np.asarray(_series, dtype=float)
137+
except (TypeError, ValueError):
138+
nan_count = int(_series.isna().sum())
139+
if nan_count:
140+
logger.warning(
141+
f"Found {nan_count} NaN values in color data. "
142+
"These observations will be colored with the 'na_color'."
143+
)
144+
color_vector = _series.to_numpy()
145+
else:
146+
if np.isnan(color_vector).any():
147+
nan_count = int(np.isnan(color_vector).sum())
148+
logger.warning(
149+
f"Found {nan_count} NaN values in color data. "
150+
"These observations will be colored with the 'na_color'."
151+
)
153152

154153
# Using dict.fromkeys here since set returns in arbitrary order
155154
# remove the color of NaN values, else it might be assigned to a category
@@ -476,10 +475,33 @@ def _render_shapes(
476475
if not values_are_categorical:
477476
vmin = render_params.cmap_params.norm.vmin
478477
vmax = render_params.cmap_params.norm.vmax
479-
if vmin is None:
480-
vmin = float(np.nanmin(color_vector))
481-
if vmax is None:
482-
vmax = float(np.nanmax(color_vector))
478+
if vmin is None or vmax is None:
479+
# Extract numeric values only (filter out strings and other non-numeric types)
480+
if isinstance(color_vector, np.ndarray):
481+
if np.issubdtype(color_vector.dtype, np.number):
482+
# Already numeric, can use directly
483+
numeric_values = color_vector
484+
else:
485+
# Mixed types - extract only numeric values using pandas
486+
numeric_values = pd.to_numeric(color_vector, errors="coerce")
487+
numeric_values = numeric_values[np.isfinite(numeric_values)]
488+
if len(numeric_values) > 0:
489+
if vmin is None:
490+
vmin = float(np.nanmin(numeric_values))
491+
if vmax is None:
492+
vmax = float(np.nanmax(numeric_values))
493+
else:
494+
# No numeric values found, use defaults
495+
if vmin is None:
496+
vmin = 0.0
497+
if vmax is None:
498+
vmax = 1.0
499+
else:
500+
# Not a numpy array, use defaults
501+
if vmin is None:
502+
vmin = 0.0
503+
if vmax is None:
504+
vmax = 1.0
483505
_cax.set_clim(vmin=vmin, vmax=vmax)
484506

485507
if (
@@ -541,31 +563,16 @@ def _render_points(
541563
coords = ["x", "y"]
542564

543565
if table_name is not None and col_for_color not in points.columns:
544-
warnings.warn(
566+
logger.warning(
545567
f"Annotating points with {col_for_color} which is stored in the table `{table_name}`. "
546-
f"To improve performance, it is advisable to store point annotations directly in the .parquet file.",
547-
UserWarning,
548-
stacklevel=2,
568+
f"To improve performance, it is advisable to store point annotations directly in the .parquet file."
549569
)
550570

551571
if col_for_color is None or (
552572
table_name is not None
553573
and (col_for_color in sdata_filt[table_name].obs.columns or col_for_color in sdata_filt[table_name].var_names)
554574
):
555575
points = points[coords].compute()
556-
if (
557-
col_for_color
558-
and col_for_color in sdata_filt[table_name].obs.columns
559-
and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O"
560-
and not _is_coercable_to_float(color_col)
561-
):
562-
warnings.warn(
563-
f"Converting copy of '{col_for_color}' column to categorical dtype for categorical "
564-
f"plotting. Consider converting before plotting.",
565-
UserWarning,
566-
stacklevel=2,
567-
)
568-
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")
569576
else:
570577
coords += [col_for_color]
571578
points = points[coords].compute()
@@ -683,6 +690,7 @@ def _render_points(
683690
alpha=render_params.alpha,
684691
table_name=table_name,
685692
render_type="points",
693+
coordinate_system=coordinate_system,
686694
)
687695

688696
if added_color_from_table and col_for_color is not None:
@@ -1219,6 +1227,7 @@ def _render_labels(
12191227
cmap_params=render_params.cmap_params,
12201228
table_name=table_name,
12211229
table_layer=table_layer,
1230+
coordinate_system=coordinate_system,
12221231
)
12231232

12241233
# rasterize could have removed labels from label

0 commit comments

Comments
 (0)