11from __future__ import annotations
22
3- import warnings
43from collections import abc
54from copy import copy
65
4948 _get_extent_and_range_for_datashader_canvas ,
5049 _get_linear_colormap ,
5150 _hex_no_alpha ,
52- _is_coercable_to_float ,
5351 _map_color_seg ,
5452 _maybe_set_colors ,
5553 _mpl_ax_contains_elements ,
@@ -94,20 +92,7 @@ def _render_shapes(
9492 )
9593 sdata_filt [table_name ] = table = joined_table
9694
97- if (
98- col_for_color is not None
99- and table_name is not None
100- and col_for_color in sdata_filt [table_name ].obs .columns
101- and (color_col := sdata_filt [table_name ].obs [col_for_color ]).dtype == "O"
102- and not _is_coercable_to_float (color_col )
103- ):
104- warnings .warn (
105- f"Converting copy of '{ col_for_color } ' column to categorical dtype for categorical plotting. "
106- f"Consider converting before plotting." ,
107- UserWarning ,
108- stacklevel = 2 ,
109- )
110- sdata_filt [table_name ].obs [col_for_color ] = sdata_filt [table_name ].obs [col_for_color ].astype ("category" )
95+ shapes = sdata_filt [element ]
11196
11297 # get color vector (categorical or continuous)
11398 color_source_vector , color_vector , _ = _set_color_source_vec (
@@ -121,6 +106,7 @@ def _render_shapes(
121106 cmap_params = render_params .cmap_params ,
122107 table_name = table_name ,
123108 table_layer = table_layer ,
109+ coordinate_system = coordinate_system ,
124110 )
125111
126112 values_are_categorical = color_source_vector is not None
@@ -144,12 +130,25 @@ def _render_shapes(
144130
145131 # continuous case: leave NaNs as NaNs; utils maps them to na_color during draw
146132 if color_source_vector is None and not values_are_categorical :
147- color_vector = np .asarray (color_vector , dtype = float )
148- if np .isnan (color_vector ).any ():
149- nan_count = int (np .isnan (color_vector ).sum ())
150- msg = f"Found { nan_count } NaN values in color data. These observations will be colored with the 'na_color'."
151- warnings .warn (msg , UserWarning , stacklevel = 2 )
152- logger .warning (msg )
133+ _series = color_vector if isinstance (color_vector , pd .Series ) else pd .Series (color_vector )
134+
135+ try :
136+ color_vector = np .asarray (_series , dtype = float )
137+ except (TypeError , ValueError ):
138+ nan_count = int (_series .isna ().sum ())
139+ if nan_count :
140+ logger .warning (
141+ f"Found { nan_count } NaN values in color data. "
142+ "These observations will be colored with the 'na_color'."
143+ )
144+ color_vector = _series .to_numpy ()
145+ else :
146+ if np .isnan (color_vector ).any ():
147+ nan_count = int (np .isnan (color_vector ).sum ())
148+ logger .warning (
149+ f"Found { nan_count } NaN values in color data. "
150+ "These observations will be colored with the 'na_color'."
151+ )
153152
154153 # Using dict.fromkeys here since set returns in arbitrary order
155154 # remove the color of NaN values, else it might be assigned to a category
@@ -476,10 +475,33 @@ def _render_shapes(
476475 if not values_are_categorical :
477476 vmin = render_params .cmap_params .norm .vmin
478477 vmax = render_params .cmap_params .norm .vmax
479- if vmin is None :
480- vmin = float (np .nanmin (color_vector ))
481- if vmax is None :
482- vmax = float (np .nanmax (color_vector ))
478+ if vmin is None or vmax is None :
479+ # Extract numeric values only (filter out strings and other non-numeric types)
480+ if isinstance (color_vector , np .ndarray ):
481+ if np .issubdtype (color_vector .dtype , np .number ):
482+ # Already numeric, can use directly
483+ numeric_values = color_vector
484+ else :
485+ # Mixed types - extract only numeric values using pandas
486+ numeric_values = pd .to_numeric (color_vector , errors = "coerce" )
487+ numeric_values = numeric_values [np .isfinite (numeric_values )]
488+ if len (numeric_values ) > 0 :
489+ if vmin is None :
490+ vmin = float (np .nanmin (numeric_values ))
491+ if vmax is None :
492+ vmax = float (np .nanmax (numeric_values ))
493+ else :
494+ # No numeric values found, use defaults
495+ if vmin is None :
496+ vmin = 0.0
497+ if vmax is None :
498+ vmax = 1.0
499+ else :
500+ # Not a numpy array, use defaults
501+ if vmin is None :
502+ vmin = 0.0
503+ if vmax is None :
504+ vmax = 1.0
483505 _cax .set_clim (vmin = vmin , vmax = vmax )
484506
485507 if (
@@ -541,31 +563,16 @@ def _render_points(
541563 coords = ["x" , "y" ]
542564
543565 if table_name is not None and col_for_color not in points .columns :
544- warnings . warn (
566+ logger . warning (
545567 f"Annotating points with { col_for_color } which is stored in the table `{ table_name } `. "
546- f"To improve performance, it is advisable to store point annotations directly in the .parquet file." ,
547- UserWarning ,
548- stacklevel = 2 ,
568+ f"To improve performance, it is advisable to store point annotations directly in the .parquet file."
549569 )
550570
551571 if col_for_color is None or (
552572 table_name is not None
553573 and (col_for_color in sdata_filt [table_name ].obs .columns or col_for_color in sdata_filt [table_name ].var_names )
554574 ):
555575 points = points [coords ].compute ()
556- if (
557- col_for_color
558- and col_for_color in sdata_filt [table_name ].obs .columns
559- and (color_col := sdata_filt [table_name ].obs [col_for_color ]).dtype == "O"
560- and not _is_coercable_to_float (color_col )
561- ):
562- warnings .warn (
563- f"Converting copy of '{ col_for_color } ' column to categorical dtype for categorical "
564- f"plotting. Consider converting before plotting." ,
565- UserWarning ,
566- stacklevel = 2 ,
567- )
568- sdata_filt [table_name ].obs [col_for_color ] = sdata_filt [table_name ].obs [col_for_color ].astype ("category" )
569576 else :
570577 coords += [col_for_color ]
571578 points = points [coords ].compute ()
@@ -683,6 +690,7 @@ def _render_points(
683690 alpha = render_params .alpha ,
684691 table_name = table_name ,
685692 render_type = "points" ,
693+ coordinate_system = coordinate_system ,
686694 )
687695
688696 if added_color_from_table and col_for_color is not None :
@@ -1219,6 +1227,7 @@ def _render_labels(
12191227 cmap_params = render_params .cmap_params ,
12201228 table_name = table_name ,
12211229 table_layer = table_layer ,
1230+ coordinate_system = coordinate_system ,
12221231 )
12231232
12241233 # rasterize could have removed labels from label
0 commit comments