@@ -570,6 +570,23 @@ def _render_points(
570570 coords += [col_for_color ]
571571 points = points [coords ].compute ()
572572
573+ added_color_from_table = False
574+ if col_for_color is not None and col_for_color not in points .columns :
575+ color_values = get_values (
576+ value_key = col_for_color ,
577+ sdata = sdata_filt ,
578+ element_name = element ,
579+ table_name = table_name ,
580+ table_layer = table_layer ,
581+ )
582+ points = points .merge (
583+ color_values [[col_for_color ]],
584+ how = "left" ,
585+ left_index = True ,
586+ right_index = True ,
587+ )
588+ added_color_from_table = True
589+
573590 if groups is not None and col_for_color is not None :
574591 if col_for_color in points .columns :
575592 points_color_values = points [col_for_color ]
@@ -588,6 +605,14 @@ def _render_points(
588605 if len (points ) <= 0 :
589606 raise ValueError (f"None of the groups { groups } could be found in the column '{ col_for_color } '." )
590607
608+ n_points = len (points )
609+ points_pd_with_color = points
610+ points_for_model = (
611+ points_pd_with_color .drop (columns = [col_for_color ], errors = "ignore" )
612+ if added_color_from_table and col_for_color is not None
613+ else points_pd_with_color
614+ )
615+
591616 # we construct an anndata to hack the plotting functions
592617 if table_name is None :
593618 adata = AnnData (
@@ -617,7 +642,7 @@ def _render_points(
617642
618643 # Convert back to dask dataframe to modify sdata
619644 transformation_in_cs = sdata_filt .points [element ].attrs ["transform" ][coordinate_system ]
620- points = dask .dataframe .from_pandas (points , npartitions = 1 )
645+ points = dask .dataframe .from_pandas (points_for_model , npartitions = 1 )
621646 sdata_filt .points [element ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
622647 # restore transformation in coordinate system of interest
623648 set_transformation (
@@ -658,6 +683,16 @@ def _render_points(
658683 render_type = "points" ,
659684 )
660685
686+ if added_color_from_table and col_for_color is not None :
687+ points_with_color_dd = dask .dataframe .from_pandas (points_pd_with_color , npartitions = 1 )
688+ sdata_filt .points [element ] = PointsModel .parse (points_with_color_dd , coordinates = {"x" : "x" , "y" : "y" })
689+ set_transformation (
690+ element = sdata_filt .points [element ],
691+ transformation = transformation_in_cs ,
692+ to_coordinate_system = coordinate_system ,
693+ )
694+ points = points_with_color_dd
695+
661696 # color_source_vector is None when the values aren't categorical
662697 if color_source_vector is None and render_params .transfunc is not None :
663698 color_vector = render_params .transfunc (color_vector )
@@ -669,7 +704,7 @@ def _render_points(
669704 method = render_params .method
670705
671706 if method is None :
672- method = "datashader" if len ( points ) > 10000 else "matplotlib"
707+ method = "datashader" if n_points > 10000 else "matplotlib"
673708
674709 if method != "matplotlib" :
675710 # we only notify the user when we switched away from matplotlib
0 commit comments