Skip to content

Commit c454708

Browse files
Add new color_labels filter to add color scalars to labeled data (pyvista#7024)
* Initial implementation * Add test coverage * Add test coverage for inputs * Remove progress bar * Add scalars test coverage * Update colors and output scalars * Fix typing * Fix line ending * Rename coloring_mode param, update docs * Revert changes to validate_color_sequence * Move import * Update color validation * Move validation function * Revert colors changes * Fix typing * Update docstring * Update test coverage * Update test coverage and docs for partial dict input * Update docstring examples * Update docstring * Update data_set.py * Add files via upload * Update data_set.py * Fix formatting * Update types in docstrings * Support a single color * Update tests * Fix typing * Add files via upload * Add color_type and make int_rgb the default * Fix typing * Append rgb instead of rgba for rgb types * Update docs * Add test * Set 'cell' as default preference * Add initial default plot for comparison * Add clarity to color map info * Add missing word * Add files via upload --------- Co-authored-by: Tetsuo Koyama <tkoyama010@gmail.com>
1 parent 75bca11 commit c454708

10 files changed

+454
-1
lines changed

pyvista/core/_validation/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ def _validate_color_sequence(
12771277
color_list = color_list * n_colors
12781278

12791279
# Only return if we have the correct number of colors
1280-
if n_colors is not None and len(color_list) == n_colors:
1280+
if n_colors is None or len(color_list) == n_colors:
12811281
return tuple(color_list)
12821282
except ValueError:
12831283
pass

pyvista/core/filters/data_set.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import contextlib
88
import functools
99
from typing import TYPE_CHECKING
10+
from typing import Any
1011
from typing import Literal
1112
from typing import cast
1213
import warnings
@@ -38,6 +39,7 @@
3839
from pyvista.core.utilities.transform import Transform
3940

4041
if TYPE_CHECKING: # pragma: no cover
42+
from pyvista import Color
4143
from pyvista import DataSet
4244
from pyvista import MultiBlock
4345
from pyvista import PolyData
@@ -48,6 +50,7 @@
4850
from pyvista.core._typing_core import TransformLike
4951
from pyvista.core._typing_core import VectorLike
5052
from pyvista.core._typing_core._dataset_types import ConcreteDataSetAlias
53+
from pyvista.plotting._typing import ColorLike
5154

5255

5356
@abstract_class
@@ -8942,6 +8945,311 @@ def pack_labels( # type: ignore[misc]
89428945

89438946
return result
89448947

8948+
def color_labels( # type: ignore[misc]
8949+
self: ConcreteDataSetType,
8950+
colors: str
8951+
| ColorLike
8952+
| Sequence[ColorLike]
8953+
| dict[float, ColorLike] = 'glasbey_category10',
8954+
*,
8955+
coloring_mode: Literal['index', 'cycler'] | None = None,
8956+
color_type: Literal['int_rgb', 'float_rgb', 'int_rgba', 'float_rgba'] = 'int_rgb',
8957+
scalars: str | None = None,
8958+
preference: Literal['point', 'cell'] = 'cell',
8959+
output_scalars: str | None = None,
8960+
inplace: bool = False,
8961+
):
8962+
"""Add RGB(A) scalars to labeled data.
8963+
8964+
This filter adds a color array to map label values to specific colors.
8965+
The mapping can be specified explicitly with a dictionary or implicitly
8966+
with a colormap or sequence of colors. The implicit mapping is controlled
8967+
with two coloring modes:
8968+
8969+
- ``'index'`` : The input scalar values (label ids) are used as index values for
8970+
indexing the specified ``colors``. This creates a direct relationship
8971+
between labels and colors such that a given label will always have the same
8972+
color, regardless of the number of labels present in the dataset.
8973+
8974+
This option is used by default for unsigned 8-bit integer inputs, i.e.
8975+
scalars with whole numbers and a maximum range of ``[0, 255]``.
8976+
8977+
- ``'cycler'`` : The specified ``colors`` are cycled through sequentially,
8978+
and each unique value in the input scalars is assigned a color in increasing
8979+
order. Unlike with ``'index'`` mode, the colors are not directly mapped to
8980+
the labels, but instead depends on the number of labels at the input.
8981+
8982+
This option is used by default for floating-point inputs or for inputs
8983+
with values out of the range ``[0, 255]``.
8984+
8985+
By default, a new ``'int_rgb'`` array is added with the same name as the
8986+
specified ``scalars`` but with ``_rgb`` appended.
8987+
8988+
See Also
8989+
--------
8990+
pyvista.ImageDataFilters.contour_labels
8991+
Generate contours from labeled image data. The contours may be colored with this filter.
8992+
8993+
pack_labels
8994+
Make labeled data contiguous. May be used as a pre-processing step before
8995+
coloring.
8996+
8997+
Parameters
8998+
----------
8999+
colors : str | ColorLike | Sequence[ColorLike] | dict[float, ColorLike], default: 'glasbey_category10'
9000+
Color(s) to use. Specify a dictionary to explicitly control the mapping
9001+
from label values to colors. Alternatively, specify colors only using a
9002+
colormap or a sequence of colors and use ``coloring_mode`` to implicitly
9003+
control the mapping. A single color is also supported to color the entire
9004+
mesh with one color.
9005+
9006+
By default, a variation of the ``'glasbey'`` categorical colormap is used
9007+
where the first 10 colors are the same default colors used by ``matplotlib``.
9008+
See `colorcet categorical colormaps <https://colorcet.holoviz.org/user_guide/Categorical.html#>`_
9009+
for more information.
9010+
9011+
.. note::
9012+
When a dictionary is specified, any scalar values for which a key is
9013+
not provided is assigned default RGB(A) values of ``nan`` for float colors
9014+
or ``0`` for integer colors (see ``color_type``). To ensure the color
9015+
array has no default values, be sure to provide a mapping for any and
9016+
all possible input label values.
9017+
9018+
coloring_mode : 'index' | 'cycler', optional
9019+
Control how colors are mapped to label values. Has no effect if ``colors``
9020+
is a dictionary. Specify one of:
9021+
9022+
- ``'index'``: The input scalar values (label ids) are used as index
9023+
values for indexing the specified ``colors``.
9024+
- ``'cycler'``: The specified ``'colors'`` are cycled through sequentially,
9025+
and each unique value in the input scalars is assigned a color in increasing
9026+
order.
9027+
9028+
color_type : 'int_rgb' | 'float_rgb' | 'int_rgba' | 'float_rgba', default: 'int_rgb'
9029+
Type of the color array to store. By default, the colors are stored as
9030+
RGB integers to reduce memory usage.
9031+
9032+
.. note::
9033+
The color type affects the default value for unspecified colors when
9034+
a dictionary is used. See ``colors`` for details.
9035+
9036+
scalars : str, optional
9037+
Name of scalars with labels. Defaults to currently active scalars.
9038+
9039+
preference : str, default: "cell"
9040+
When ``scalars`` is specified, this is the preferred array
9041+
type to search for in the dataset. Must be either
9042+
``'point'`` or ``'cell'``.
9043+
9044+
output_scalars : str, optional
9045+
Name of the color scalars array. By default, the output array
9046+
is the same as ``scalars`` with `_rgb`` or ``_rgba`` appended
9047+
depending on ``color_type``.
9048+
9049+
inplace : bool, default: False
9050+
If ``True``, the mesh is updated in-place.
9051+
9052+
Returns
9053+
-------
9054+
pyvista.DataSet
9055+
Dataset with RGB(A) scalars. Output type matches input type.
9056+
9057+
Examples
9058+
--------
9059+
Load labeled data and crop it to simplify the data.
9060+
9061+
>>> from pyvista import examples
9062+
>>> import numpy as np
9063+
>>> image_labels = examples.load_channels()
9064+
>>> image_labels = image_labels.extract_subset(voi=(75, 109, 75, 109, 85, 100))
9065+
9066+
Plot the dataset with default coloring using a categorical color map. The
9067+
plotter by default uniformly samples from all 256 colors in the color map based
9068+
on the data's range.
9069+
9070+
>>> image_labels.plot(cmap='glasbey_category10')
9071+
9072+
Show label ids of the dataset.
9073+
9074+
>>> label_ids = np.unique(image_labels.active_scalars)
9075+
>>> label_ids
9076+
pyvista_ndarray([0, 1, 2, 3, 4])
9077+
9078+
Color the labels with the filter then plot them. Note that the
9079+
``'glasbey_category10'`` color map is used by default.
9080+
9081+
>>> colored_labels = image_labels.color_labels()
9082+
>>> colored_labels.plot()
9083+
9084+
Since the labels are unsigned integers, the ``'index'`` coloring mode is used
9085+
by default. Unlike the uniform sampling used by the plotter in the previous
9086+
plot, the colormap is instead indexed using the label values. This ensures
9087+
that labels have a consistent coloring regardless of the input. For example,
9088+
we can crop the dataset further.
9089+
9090+
>>> subset_labels = image_labels.extract_subset(voi=(15, 34, 28, 34, 12, 15))
9091+
9092+
And show that only three labels remain.
9093+
9094+
>>> label_ids = np.unique(subset_labels.active_scalars)
9095+
>>> label_ids
9096+
pyvista_ndarray([1, 2, 3])
9097+
9098+
Despite the changes to the dataset, the regions have the same coloring
9099+
as before.
9100+
9101+
>>> colored_labels = subset_labels.color_labels()
9102+
>>> colored_labels.plot()
9103+
9104+
Use the ``'cycler'`` coloring mode instead to map label values to colors
9105+
sequentially.
9106+
9107+
>>> colored_labels = subset_labels.color_labels(coloring_mode='cycler')
9108+
>>> colored_labels.plot()
9109+
9110+
Map the colors explicitly using a dictionary.
9111+
9112+
>>> colors = {0: 'black', 1: 'red', 2: 'lime', 3: 'blue', 4: 'yellow'}
9113+
>>> colored_labels = image_labels.color_labels(colors)
9114+
>>> colored_labels.plot()
9115+
9116+
Omit the background value from the mapping and specify float colors. When
9117+
floats are specified, values without a mapping are assigned ``nan`` values
9118+
and are not plotted by default.
9119+
9120+
>>> colors.pop(0)
9121+
'black'
9122+
>>> colored_labels = image_labels.color_labels(colors, color_type='float_rgba')
9123+
>>> colored_labels.plot()
9124+
9125+
Color all labels with a single color.
9126+
9127+
>>> colored_labels = image_labels.color_labels('red')
9128+
>>> colored_labels.plot()
9129+
9130+
"""
9131+
# Lazy import since these are from plotting module
9132+
from cycler import cycler
9133+
import matplotlib.colors
9134+
9135+
from pyvista.core._validation.validate import _validate_color_sequence
9136+
from pyvista.plotting._typing import ColorLike
9137+
from pyvista.plotting.colors import get_cmap_safe
9138+
9139+
def _local_validate_color_sequence(seq: ColorLike | Sequence[ColorLike]) -> Sequence[Color]:
9140+
try:
9141+
return _validate_color_sequence(seq)
9142+
except ValueError:
9143+
raise ValueError(
9144+
'Invalid colors. Colors must be one of:\n'
9145+
' - sequence of color-like values,\n'
9146+
' - dict with color-like values,\n'
9147+
' - named colormap string.\n'
9148+
f'Got: {seq}'
9149+
)
9150+
9151+
def _is_index_like(array_, max_value):
9152+
if np.issubdtype(array_.dtype, np.integer) or np.array_equal(array, np.floor(array_)):
9153+
min_, max_ = output_mesh.get_data_range(name)
9154+
if min_ >= 0 and max_ <= max_value:
9155+
return True
9156+
return False
9157+
9158+
_validation.check_contains(
9159+
['int_rgb', 'float_rgb', 'int_rgba', 'float_rgba'],
9160+
must_contain=color_type,
9161+
name='color_type',
9162+
)
9163+
9164+
if 'rgba' in color_type:
9165+
num_components = 4
9166+
scalars_suffix = '_rgba'
9167+
else:
9168+
num_components = 3
9169+
scalars_suffix = '_rgb'
9170+
if 'float' in color_type:
9171+
default_channel_value = np.nan
9172+
color_dtype = 'float'
9173+
else:
9174+
default_channel_value = 0
9175+
color_dtype = 'uint8'
9176+
9177+
if scalars is None:
9178+
field, name = set_default_active_scalars(self)
9179+
else:
9180+
name = scalars
9181+
field = get_array_association(self, name, preference=preference, err=True)
9182+
output_mesh = self if inplace else self.copy()
9183+
data = output_mesh.point_data if field == FieldAssociation.POINT else output_mesh.cell_data
9184+
array = data[name]
9185+
9186+
if isinstance(colors, dict):
9187+
if coloring_mode is not None:
9188+
raise TypeError('Coloring mode cannot be set when a color dictionary is specified.')
9189+
colors_ = _local_validate_color_sequence(cast(list[ColorLike], list(colors.values())))
9190+
color_rgb_sequence = [getattr(c, color_type) for c in colors_]
9191+
items = zip(colors.keys(), color_rgb_sequence)
9192+
9193+
else:
9194+
_is_rgb_sequence = False
9195+
if isinstance(colors, str):
9196+
try:
9197+
cmap = get_cmap_safe(colors)
9198+
except ValueError:
9199+
pass
9200+
else:
9201+
if not isinstance(cmap, matplotlib.colors.ListedColormap):
9202+
raise ValueError(
9203+
f"Colormap '{colors}' must be a ListedColormap, got {cmap.__class__.__name__} instead."
9204+
)
9205+
# Avoid unnecessary conversion and set color sequence directly in float cases
9206+
cmap_colors = cast(list[list[float]], cmap.colors)
9207+
if color_type == 'float_rgb':
9208+
color_rgb_sequence = cmap_colors
9209+
_is_rgb_sequence = True
9210+
elif color_type == 'float_rgba':
9211+
color_rgb_sequence = [(*c, 1.0) for c in cmap_colors]
9212+
_is_rgb_sequence = True
9213+
else:
9214+
colors = cmap_colors
9215+
9216+
if not _is_rgb_sequence:
9217+
color_rgb_sequence = [
9218+
getattr(c, color_type) for c in _local_validate_color_sequence(colors)
9219+
]
9220+
if len(color_rgb_sequence) == 1:
9221+
color_rgb_sequence = color_rgb_sequence * len(array)
9222+
9223+
n_colors = len(color_rgb_sequence)
9224+
if coloring_mode is None:
9225+
coloring_mode = 'index' if _is_index_like(array, max_value=n_colors) else 'cycler'
9226+
9227+
if coloring_mode == 'index':
9228+
if not _is_index_like(array, max_value=n_colors):
9229+
raise ValueError(
9230+
f"Index coloring mode cannot be used with scalars '{name}'. Scalars must be positive integers \n"
9231+
f'and the max value ({self.get_data_range(name)[1]}) must be less than the number of colors ({n_colors}).'
9232+
)
9233+
keys: Iterable[float] = range(n_colors)
9234+
values: Iterable[Any] = color_rgb_sequence
9235+
else:
9236+
keys = np.unique(array)
9237+
values = cycler('color', color_rgb_sequence)
9238+
9239+
items = zip(keys, values)
9240+
9241+
colors_out = np.full((len(array), num_components), default_channel_value, dtype=color_dtype)
9242+
for label, color in items:
9243+
if isinstance(color, dict):
9244+
color = color['color']
9245+
colors_out[array == label, :] = color
9246+
9247+
colors_name = name + scalars_suffix if output_scalars is None else output_scalars
9248+
data[colors_name] = colors_out
9249+
output_mesh.set_active_scalars(colors_name)
9250+
9251+
return output_mesh
9252+
89459253

89469254
def _set_threshold_limit(alg, value, method, invert):
89479255
"""Set vtkThreshold limits and function.

0 commit comments

Comments
 (0)