Skip to content

render_shapes now respects the cmap parameter #436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 87 additions & 20 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_get_extent_and_range_for_datashader_canvas,
_get_linear_colormap,
_get_transformation_matrix_for_datashader,
_hex_no_alpha,
_is_coercable_to_float,
_map_color_seg,
_maybe_set_colors,
Expand Down Expand Up @@ -191,7 +192,10 @@ def _render_shapes(
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
)
transformed_element = ShapesModel.parse(
gpd.GeoDataFrame(data=sdata_filt.shapes[element].drop("geometry", axis=1), geometry=transformed_element)
gpd.GeoDataFrame(
data=sdata_filt.shapes[element].drop("geometry", axis=1),
geometry=transformed_element,
)
)

plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
Expand All @@ -208,15 +212,23 @@ def _render_shapes(
aggregate_with_reduction = None
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
if color_by_categorical:
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count()))
agg = cvs.polygons(
transformed_element,
geometry="geometry",
agg=ds.by(col_for_color, ds.count()),
)
else:
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean"
logger.info(
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
"to the matplotlib result."
)
agg = _datashader_aggregate_with_function(
render_params.ds_reduction, cvs, transformed_element, col_for_color, "shapes"
render_params.ds_reduction,
cvs,
transformed_element,
col_for_color,
"shapes",
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
Expand Down Expand Up @@ -246,7 +258,7 @@ def _render_shapes(
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)

color_key = (
[x[:-2] for x in color_vector.categories.values]
[_hex_no_alpha(x) for x in color_vector.categories.values]
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
and (len(color_vector.categories.values) > 1)
else None
Expand All @@ -257,7 +269,7 @@ def _render_shapes(
if color_vector is not None:
ds_cmap = color_vector[0]
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
ds_cmap = ds_cmap[:-2]
ds_cmap = _hex_no_alpha(ds_cmap)

ds_result = _datashader_map_aggregate_to_color(
agg,
Expand All @@ -272,7 +284,10 @@ def _render_shapes(
# else: all elements would get alpha=0 and the color bar would have a weird range
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
aggregate_with_reduction = (
aggregate_with_reduction[0],
aggregate_with_reduction[0] + 1,
)

ds_result = _datashader_map_aggregate_to_color(
agg,
Expand Down Expand Up @@ -468,7 +483,9 @@ def _render_points(
# we construct an anndata to hack the plotting functions
if table_name is None:
adata = AnnData(
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
X=points[["x", "y"]].values,
obs=points[coords].reset_index(),
dtype=points[["x", "y"]].values.dtype,
)
else:
adata_obs = sdata_filt[table_name].obs
Expand Down Expand Up @@ -496,7 +513,9 @@ def _render_points(
sdata_filt.points[element] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})
# restore transformation in coordinate system of interest
set_transformation(
element=sdata_filt.points[element], transformation=transformation_in_cs, to_coordinate_system=coordinate_system
element=sdata_filt.points[element],
transformation=transformation_in_cs,
to_coordinate_system=coordinate_system,
)

if col_for_color is not None:
Expand Down Expand Up @@ -586,7 +605,11 @@ def _render_points(
"to the matplotlib result."
)
agg = _datashader_aggregate_with_function(
render_params.ds_reduction, cvs, transformed_element, col_for_color, "points"
render_params.ds_reduction,
cvs,
transformed_element,
col_for_color,
"points",
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
Expand Down Expand Up @@ -642,7 +665,10 @@ def _render_points(
# else: all elements would get alpha=0 and the color bar would have a weird range
if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]):
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
aggregate_with_reduction = (
aggregate_with_reduction[0],
aggregate_with_reduction[0] + 1,
)

ds_result = _datashader_map_aggregate_to_color(
agg,
Expand Down Expand Up @@ -805,7 +831,12 @@ def _render_images(

# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
_ax_show_and_transform(
layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
layer,
trans_data,
ax,
cmap=cmap,
zorder=render_params.zorder,
norm=render_params.cmap_params.norm,
)

if legend_params.colorbar:
Expand All @@ -832,7 +863,11 @@ def _render_images(
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
/ n_channels
)
stacked = stacked[:, :, :3]
logger.warning(
Expand All @@ -844,7 +879,13 @@ def _render_images(
"Consider using 'palette' instead."
)

_ax_show_and_transform(stacked, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
_ax_show_and_transform(
stacked,
trans_data,
ax,
render_params.alpha,
zorder=render_params.zorder,
)

# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
elif palette is None and not got_multiple_cmaps:
Expand All @@ -858,7 +899,13 @@ def _render_images(
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
_ax_show_and_transform(
colored,
trans_data,
ax,
render_params.alpha,
zorder=render_params.zorder,
)

# 2C) Image has n channels and palette info
elif palette is not None and not got_multiple_cmaps:
Expand All @@ -869,16 +916,32 @@ def _render_images(
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
_ax_show_and_transform(
colored,
trans_data,
ax,
render_params.alpha,
zorder=render_params.zorder,
)

elif palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
colored = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
/ n_channels
)
colored = colored[:, :, :3]

_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
_ax_show_and_transform(
colored,
trans_data,
ax,
render_params.alpha,
zorder=render_params.zorder,
)

elif palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
Expand Down Expand Up @@ -999,7 +1062,9 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
# outline-only case
elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0:
cax = _draw_labels(
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
seg_erosionpx=render_params.contour_px,
seg_boundaries=True,
alpha=render_params.outline_alpha,
)
alpha_to_decorate_ax = render_params.outline_alpha

Expand All @@ -1010,7 +1075,9 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)

# ... then overlay the contour
cax_contour = _draw_labels(
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
seg_erosionpx=render_params.contour_px,
seg_boundaries=True,
alpha=render_params.outline_alpha,
)

# pass the less-transparent _cax for the legend
Expand All @@ -1035,7 +1102,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
legend_fontweight=legend_params.legend_fontweight,
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector)),
na_in_legend=(legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector))),
colorbar=legend_params.colorbar,
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
Expand Down
Loading
Loading