Skip to content

Commit b07ada7

Browse files
authored
When using render_points() with datashader, points can now be colored by adata.obs (#511)
1 parent 04e17d4 commit b07ada7

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,23 @@ def _render_points(
570570
coords += [col_for_color]
571571
points = points[coords].compute()
572572

573+
added_color_from_table = False
574+
if col_for_color is not None and col_for_color not in points.columns:
575+
color_values = get_values(
576+
value_key=col_for_color,
577+
sdata=sdata_filt,
578+
element_name=element,
579+
table_name=table_name,
580+
table_layer=table_layer,
581+
)
582+
points = points.merge(
583+
color_values[[col_for_color]],
584+
how="left",
585+
left_index=True,
586+
right_index=True,
587+
)
588+
added_color_from_table = True
589+
573590
if groups is not None and col_for_color is not None:
574591
if col_for_color in points.columns:
575592
points_color_values = points[col_for_color]
@@ -588,6 +605,14 @@ def _render_points(
588605
if len(points) <= 0:
589606
raise ValueError(f"None of the groups {groups} could be found in the column '{col_for_color}'.")
590607

608+
n_points = len(points)
609+
points_pd_with_color = points
610+
points_for_model = (
611+
points_pd_with_color.drop(columns=[col_for_color], errors="ignore")
612+
if added_color_from_table and col_for_color is not None
613+
else points_pd_with_color
614+
)
615+
591616
# we construct an anndata to hack the plotting functions
592617
if table_name is None:
593618
adata = AnnData(
@@ -617,7 +642,7 @@ def _render_points(
617642

618643
# Convert back to dask dataframe to modify sdata
619644
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
620-
points = dask.dataframe.from_pandas(points, npartitions=1)
645+
points = dask.dataframe.from_pandas(points_for_model, npartitions=1)
621646
sdata_filt.points[element] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})
622647
# restore transformation in coordinate system of interest
623648
set_transformation(
@@ -658,6 +683,16 @@ def _render_points(
658683
render_type="points",
659684
)
660685

686+
if added_color_from_table and col_for_color is not None:
687+
points_with_color_dd = dask.dataframe.from_pandas(points_pd_with_color, npartitions=1)
688+
sdata_filt.points[element] = PointsModel.parse(points_with_color_dd, coordinates={"x": "x", "y": "y"})
689+
set_transformation(
690+
element=sdata_filt.points[element],
691+
transformation=transformation_in_cs,
692+
to_coordinate_system=coordinate_system,
693+
)
694+
points = points_with_color_dd
695+
661696
# color_source_vector is None when the values aren't categorical
662697
if color_source_vector is None and render_params.transfunc is not None:
663698
color_vector = render_params.transfunc(color_vector)
@@ -669,7 +704,7 @@ def _render_points(
669704
method = render_params.method
670705

671706
if method is None:
672-
method = "datashader" if len(points) > 10000 else "matplotlib"
707+
method = "datashader" if n_points > 10000 else "matplotlib"
673708

674709
if method != "matplotlib":
675710
# we only notify the user when we switched away from matplotlib
26.5 KB
Loading

tests/pl/test_render_points.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,32 @@ def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData):
178178
method="datashader",
179179
).pl.show()
180180

181+
def test_plot_datashader_colors_from_table_obs(self, sdata_blobs: SpatialData):
182+
n_obs = len(sdata_blobs["blobs_points"])
183+
obs = pd.DataFrame(
184+
{
185+
"instance_id": np.arange(n_obs),
186+
"region": pd.Categorical(["blobs_points"] * n_obs),
187+
"foo": pd.Categorical(np.where(np.arange(n_obs) % 2 == 0, "a", "b")),
188+
}
189+
)
190+
191+
table = TableModel.parse(
192+
adata=AnnData(get_standard_RNG().normal(size=(n_obs, 3)), obs=obs),
193+
region="blobs_points",
194+
region_key="region",
195+
instance_key="instance_id",
196+
)
197+
sdata_blobs["datashader_table"] = table
198+
199+
sdata_blobs.pl.render_points(
200+
"blobs_points",
201+
color="foo",
202+
table_name="datashader_table",
203+
method="datashader",
204+
size=5,
205+
).pl.show()
206+
181207
def test_plot_datashader_can_use_sum_as_reduction(self, sdata_blobs: SpatialData):
182208
sdata_blobs.pl.render_points(
183209
element="blobs_points",
@@ -487,3 +513,31 @@ def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
487513
table_name="other_table",
488514
).pl.show()
489515
)
516+
517+
518+
def test_datashader_colors_points_from_table_obs(sdata_blobs: SpatialData):
519+
# Fast regression for https://github.com/scverse/spatialdata-plot/issues/479.
520+
n_obs = len(sdata_blobs["blobs_points"])
521+
obs = pd.DataFrame(
522+
{
523+
"instance_id": np.arange(n_obs),
524+
"region": pd.Categorical(["blobs_points"] * n_obs),
525+
"foo": pd.Categorical(np.where(np.arange(n_obs) % 2 == 0, "a", "b")),
526+
}
527+
)
528+
529+
table = TableModel.parse(
530+
adata=AnnData(get_standard_RNG().normal(size=(n_obs, 3)), obs=obs),
531+
region="blobs_points",
532+
region_key="region",
533+
instance_key="instance_id",
534+
)
535+
sdata_blobs["datashader_table"] = table
536+
537+
sdata_blobs.pl.render_points(
538+
"blobs_points",
539+
color="foo",
540+
table_name="datashader_table",
541+
method="datashader",
542+
size=5,
543+
).pl.show()

0 commit comments

Comments
 (0)