diff --git a/examples/rgb.py b/examples/rgb.py new file mode 100644 index 00000000..d73cd789 --- /dev/null +++ b/examples/rgb.py @@ -0,0 +1,8 @@ +# /// script +# dependencies = [ +# "ndv[pyqt,vispy]", +# ] +# /// +import ndv + +n = ndv.imshow(ndv.data.rgba()) diff --git a/src/ndv/_types.py b/src/ndv/_types.py index 3988fdc4..3716fa9a 100644 --- a/src/ndv/_types.py +++ b/src/ndv/_types.py @@ -6,7 +6,7 @@ from contextlib import suppress from enum import Enum, IntFlag, auto from functools import cache -from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, Optional, cast +from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, cast from pydantic import PlainSerializer, PlainValidator from typing_extensions import TypeAlias @@ -52,9 +52,12 @@ def _to_slice(val: Any) -> slice: AxisKey: TypeAlias = Annotated[ Hashable, PlainValidator(_maybe_int), PlainSerializer(str, return_type=str) ] - -# A channel key is a value that can be used to identify a channel. -ChannelKey: TypeAlias = Optional[int] +# An channel key is any hashable object that can be used to describe a position along +# an axis. In many cases it will be an integer, but it might also provide a contextual +# label for one or more positions. +ChannelKey: TypeAlias = Annotated[ + Hashable, PlainValidator(_maybe_int), PlainSerializer(str, return_type=str) +] class MouseButton(IntFlag): diff --git a/src/ndv/controllers/_array_viewer.py b/src/ndv/controllers/_array_viewer.py index e7502b5a..2cd34f76 100644 --- a/src/ndv/controllers/_array_viewer.py +++ b/src/ndv/controllers/_array_viewer.py @@ -2,7 +2,7 @@ import os import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -60,13 +60,17 @@ def __init__( display_model: ArrayDisplayModel | None = None, **kwargs: Unpack[ArrayDisplayModelKwargs], ) -> None: - if display_model is not None and kwargs: + wrapper = None if data is None else DataWrapper.create(data) + if display_model is None: + display_model = self._default_display_model(wrapper, **kwargs) + elif kwargs: warnings.warn( "When display_model is provided, kwargs are be ignored.", stacklevel=2, ) + self._data_model = _ArrayDataDisplayModel( - data_wrapper=data, display=display_model or ArrayDisplayModel(**kwargs) + data_wrapper=wrapper, display=display_model ) self._viewer_model = ArrayViewerModel() self._viewer_model.events.interaction_mode.connect( @@ -218,6 +222,27 @@ def clone(self) -> ArrayViewer: # --------------------- PRIVATE ------------------------------------------ + @staticmethod + def _default_display_model( + data: None | DataWrapper, **kwargs: Unpack[ArrayDisplayModelKwargs] + ) -> ArrayDisplayModel: + """ + Creates a default ArrayDisplayModel when none is provided by the user. + + All magical setup goes here. + """ + # Can't do any magic with no data + if data is None: + return ArrayDisplayModel(**kwargs) + + # cast 3d+ images with shape[-1] of {3,4} to RGB images + if "channel_mode" not in kwargs and "channel_axis" not in kwargs: + shape = tuple(data.sizes().values()) + if len(shape) >= 3 and shape[-1] in {3, 4}: + kwargs["channel_axis"] = -1 + kwargs["channel_mode"] = "rgba" + return ArrayDisplayModel(**kwargs) + def _add_histogram(self, channel: ChannelKey = None) -> None: histogram_cls = _app.get_histogram_canvas_class() # will raise if not supported hist = histogram_cls() @@ -262,6 +287,7 @@ def _set_model_connected( for obj, callback in [ (model.events.visible_axes, self._on_model_visible_axes_changed), + (model.events.channel_axis, self._on_model_channel_axis_changed), # the current_index attribute itself is immutable (model.current_index.value_changed, self._on_model_current_index_changed), (model.events.channel_mode, self._on_model_channel_mode_changed), @@ -330,21 +356,30 @@ def _on_model_visible_axes_changed(self) -> None: self._canvas.set_ndim(self.display_model.n_visible_axes) self._request_data() + def _on_model_channel_axis_changed(self) -> None: + self._request_data() + def _on_model_current_index_changed(self) -> None: value = self._data_model.display.current_index self._view.set_current_index(value) self._request_data() def _on_model_channel_mode_changed(self, mode: ChannelMode) -> None: + # When the channel view changes, two things must be done: self._view.set_channel_mode(mode) + # 1. A slider must be shown for each axis that is not a: + # (a) channel axis + # (b) visible axis self._update_visible_sliders() - show_channel_luts = mode in {ChannelMode.COLOR, ChannelMode.COMPOSITE} + # 2. LutViews must be updated: for lut_ctrl in self._lut_controllers.values(): for view in lut_ctrl.lut_views: if lut_ctrl.key is None: - view.set_visible(not show_channel_luts) + view.set_visible(mode == ChannelMode.GRAYSCALE) + elif lut_ctrl.key == "RGB": + view.set_visible(mode == ChannelMode.RGBA) else: - view.set_visible(show_channel_luts) + view.set_visible(mode in {ChannelMode.COLOR, ChannelMode.COMPOSITE}) # redraw self._clear_canvas() self._request_data() @@ -563,7 +598,15 @@ def _get_values_at_world_point(self, x: int, y: int) -> dict[ChannelKey, float]: values: dict[ChannelKey, float] = {} for key, ctrl in self._lut_controllers.items(): if (value := ctrl.get_value_at_index((y, x))) is not None: - values[key] = value + # Handle RGB + if key == "RGB" and isinstance(value, np.ndarray): + values["R"] = value[0] + values["G"] = value[1] + values["B"] = value[2] + if value.shape[0] > 3: + values["A"] = value[3] + else: + values[key] = cast("float", value) return values diff --git a/src/ndv/controllers/_channel_controller.py b/src/ndv/controllers/_channel_controller.py index 22c1a171..af9761ca 100644 --- a/src/ndv/controllers/_channel_controller.py +++ b/src/ndv/controllers/_channel_controller.py @@ -8,12 +8,11 @@ import numpy as np + from ndv._types import ChannelKey from ndv.models._lut_model import LUTModel from ndv.views.bases import LutView from ndv.views.bases._graphics._canvas_elements import ImageHandle - LutKey = int | None - class ChannelController: """Controller for a single channel in the viewer. @@ -25,7 +24,7 @@ class ChannelController: """ def __init__( - self, key: LutKey, lut_model: LUTModel, views: Sequence[LutView] + self, key: ChannelKey, lut_model: LUTModel, views: Sequence[LutView] ) -> None: self.key = key self.lut_views: list[LutView] = [] @@ -67,7 +66,7 @@ def add_handle(self, handle: ImageHandle) -> None: self.handles.append(handle) self.add_lut_view(handle) - def get_value_at_index(self, idx: tuple[int, ...]) -> float | None: + def get_value_at_index(self, idx: tuple[int, ...]) -> np.ndarray | float | None: """Get the value of the data at the given index.""" if not (handles := self.handles): return None @@ -78,7 +77,7 @@ def get_value_at_index(self, idx: tuple[int, ...]) -> float | None: # stored by the backend visual, rather than querying the data itself # this is a quick workaround to get the value without having to # worry about other dimensions in the data source (since the - # texture has already been reduced to 2D). But a more complete + # texture has already been reduced to RGB/RGBA/2D). But a more complete # implementation would gather the full current nD index and query # the data source directly. return handle.data()[idx] # type: ignore [no-any-return] diff --git a/src/ndv/data.py b/src/ndv/data.py index 5e39ccee..f6c85082 100644 --- a/src/ndv/data.py +++ b/src/ndv/data.py @@ -144,3 +144,22 @@ def cosem_dataset( ).result() ts_array = ts_array[ts.d[:].label["z", "y", "x"]] return ts_array[ts.d[("y", "x", "z")].transpose[:]] + + +def rgba() -> np.ndarray: + """3D RGBA dataset: `(256, 256, 256, 4)`, uint8.""" + img = np.zeros((256, 256, 256, 4), dtype=np.uint8) + + # R,G,B are simple + for i in range(256): + img[:, i, :, 0] = i # Red + img[:, i, :, 2] = 255 - i # Blue + for j in range(256): + img[:, :, j, 1] = j # Green + + # Alpha is a bit trickier - requires a meshgrid for efficient computation + x, y, z = np.meshgrid(np.arange(256), np.arange(256), np.arange(256), indexing="ij") + alpha = np.sqrt((x - 128) ** 2 + (y - 128) ** 2 + (z - 128) ** 2) + img[:, :, :, 3] = np.clip(alpha, 0, 255) + + return img diff --git a/src/ndv/models/_array_display_model.py b/src/ndv/models/_array_display_model.py index 4d47fcf2..6bc65610 100644 --- a/src/ndv/models/_array_display_model.py +++ b/src/ndv/models/_array_display_model.py @@ -7,7 +7,7 @@ from pydantic import Field, computed_field, model_validator from typing_extensions import Self, TypeAlias -from ndv._types import AxisKey, Slice +from ndv._types import AxisKey, ChannelKey, Slice from ._base_model import NDVModel from ._lut_model import LUTModel @@ -47,7 +47,7 @@ class ArrayDisplayModelKwargs(TypedDict, total=False): # map of axis to index/slice ... i.e. the current subset of data being displayed IndexMap: TypeAlias = ValidatedEventedDict[AxisKey, Union[int, Slice]] # map of index along channel axis to LUTModel object -LutMap: TypeAlias = ValidatedEventedDict[Union[int, None], LUTModel] +LutMap: TypeAlias = ValidatedEventedDict[ChannelKey, LUTModel] # map of axis to reducer Reducers: TypeAlias = ValidatedEventedDict[Union[AxisKey, None], ReducerType] # used for visible_axes @@ -172,7 +172,6 @@ class ArrayDisplayModel(NDVModel): `luts`. """ - visible_axes: TwoOrThreeAxisTuple = (-2, -1) # NOTE: In terms of requesting data, there is a slight "delocalization" of state # here in that we probably also want to avoid requesting data for channel # positions that are not visible. @@ -184,6 +183,10 @@ class ArrayDisplayModel(NDVModel): channel_mode: ChannelMode = ChannelMode.GRAYSCALE channel_axis: Optional[AxisKey] = None + # must come after channel_axis, since it is used to set default visible_axes + visible_axes: TwoOrThreeAxisTuple = Field( + default_factory=lambda k: (-3, -2) if k.get("channel_axis") == -1 else (-2, -1) + ) # map of index along channel axis to LUTModel object luts: LutMap = Field(default_factory=_default_luts) diff --git a/src/ndv/models/_data_display_model.py b/src/ndv/models/_data_display_model.py index 556596a8..b8203b81 100644 --- a/src/ndv/models/_data_display_model.py +++ b/src/ndv/models/_data_display_model.py @@ -2,7 +2,6 @@ import sys from collections.abc import ( Hashable, - Iterable, Iterator, Mapping, MutableMapping, @@ -15,9 +14,10 @@ import numpy as np from pydantic import Field +from ndv._types import ChannelKey from ndv.views import _app -from ._array_display_model import ArrayDisplayModel, ChannelMode +from ._array_display_model import ArrayDisplayModel, ChannelMode, TwoOrThreeAxisTuple from ._base_model import NDVModel from ._data_wrapper import DataWrapper @@ -36,6 +36,7 @@ class DataRequest: index: Mapping[int, Union[int, slice]] visible_axes: tuple[int, ...] channel_axis: Optional[int] + channel_mode: ChannelMode @dataclass(frozen=True, **SLOTS) @@ -47,7 +48,7 @@ class DataResponse: # mapping of channel_key -> data n_visible_axes: int - data: Mapping[Optional[int], np.ndarray] = field(repr=False) + data: Mapping[ChannelKey, np.ndarray] = field(repr=False) request: Optional[DataRequest] = None @@ -82,19 +83,36 @@ class _ArrayDataDisplayModel(NDVModel): def model_post_init(self, __context: Any) -> None: # connect the channel mode change signal to the channel axis guessing method self.display.events.channel_mode.connect(self._on_channel_mode_change) + # initial model synchronization + self._on_channel_mode_change() def _on_channel_mode_change(self) -> None: - # if the mode is not grayscale, and the channel axis is not set, - # we let the data wrapper guess the channel axis - if ( - self.display.channel_mode != ChannelMode.GRAYSCALE - and self.display.channel_axis is None - and self.data_wrapper is not None - ): - # only use the guess if it's not already in the visible axes - guess = self.data_wrapper.guess_channel_axis() - if guess not in self.normed_visible_axes: + # TODO: Refactor into separate methods? + mode = self.display.channel_mode + if mode == ChannelMode.GRAYSCALE: + self.display.channel_axis = None + elif mode in {ChannelMode.COLOR, ChannelMode.COMPOSITE}: + if self.data_wrapper is not None: + guess = self.data_wrapper.guess_channel_axis() + # only use the guess if it's not already in the visible axes + self.display.channel_axis = ( + None if guess in self.normed_visible_axes else guess + ) + elif mode == ChannelMode.RGBA: + if self.data_wrapper is not None and self.display.channel_axis is None: + # Coerce image to RGB + if len(self.normed_visible_axes) == 3: + raise Exception("") + guess = self.data_wrapper.guess_channel_axis() self.display.channel_axis = guess + # FIXME? going back another ChannelMode retains these changes + if guess in self.normed_visible_axes: + dims = list(self.data_wrapper.sizes().keys()) + dims.remove(guess) + new_visible_axes = dims[-self.display.n_visible_axes :] + self.display.visible_axes = cast( + "TwoOrThreeAxisTuple", tuple(new_visible_axes) + ) # Properties for normalized data access ----------------------------------------- # these all use positive integers as axis keys @@ -223,6 +241,7 @@ def current_slice_requests(self) -> list[DataRequest]: index=requested_slice, visible_axes=self.normed_visible_axes, channel_axis=c_ax, + channel_mode=self.display.channel_mode, ) return [request] @@ -254,18 +273,18 @@ def process_request(req: DataRequest) -> DataResponse: vis_ax = req.visible_axes t_dims = vis_ax + tuple(i for i in range(data.ndim) if i not in vis_ax) - if (ch_ax := req.channel_axis) is not None: - ch_indices: Iterable[Optional[int]] = range(data.shape[ch_ax]) + data_response: dict[ChannelKey, np.ndarray] = {} + ch_ax = req.channel_axis + # For RGB and Grayscale - keep the whole array together + if req.channel_mode == ChannelMode.RGBA: + data_response["RGB"] = data.transpose(*t_dims).squeeze() + elif req.channel_axis is None: + data_response[None] = data.transpose(*t_dims).squeeze() + # For Composite and Color - slice along channel axis else: - ch_indices = (None,) - - data_response: dict[int | None, np.ndarray] = {} - for i in ch_indices: - if i is None: - ch_data = data - else: + for i in range(data.shape[req.channel_axis]): ch_keepdims = (slice(None),) * cast("int", ch_ax) + (i,) + (None,) ch_data = data[ch_keepdims] - data_response[i] = ch_data.transpose(*t_dims).squeeze() + data_response[i] = ch_data.transpose(*t_dims).squeeze() return DataResponse(n_visible_axes=len(vis_ax), data=data_response, request=req) diff --git a/src/ndv/views/_jupyter/_array_view.py b/src/ndv/views/_jupyter/_array_view.py index 74613fd8..5b790609 100644 --- a/src/ndv/views/_jupyter/_array_view.py +++ b/src/ndv/views/_jupyter/_array_view.py @@ -172,6 +172,12 @@ def frontend_widget(self) -> Any: return self.layout +class JupyterRGBView(JupyterLutView): + def __init__(self, channel: ChannelKey = None) -> None: + super().__init__(channel) + self._cmap.layout.display = "none" + + SPIN_GIF = str(Path(__file__).parent.parent / "_resources" / "spin.gif") @@ -205,7 +211,7 @@ def __init__( # the button that controls the display mode of the channels self._channel_mode_combo = widgets.Dropdown( - options=[ChannelMode.GRAYSCALE, ChannelMode.COMPOSITE], + options=[ChannelMode.GRAYSCALE, ChannelMode.COMPOSITE, ChannelMode.RGBA], value=str(ChannelMode.GRAYSCALE), ) self._channel_mode_combo.layout.width = "120px" @@ -343,7 +349,7 @@ def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: def add_lut_view(self, channel: ChannelKey) -> JupyterLutView: """Add a LUT view to the viewer.""" - wdg = JupyterLutView(channel) + wdg = JupyterRGBView(channel) if channel == "RGB" else JupyterLutView(channel) layout = self._luts_box self._luts[channel] = wdg diff --git a/src/ndv/views/_pygfx/_array_canvas.py b/src/ndv/views/_pygfx/_array_canvas.py index 43da46db..8e4ce770 100755 --- a/src/ndv/views/_pygfx/_array_canvas.py +++ b/src/ndv/views/_pygfx/_array_canvas.py @@ -57,8 +57,17 @@ def data(self) -> np.ndarray: return self._grid.data # type: ignore [no-any-return] def set_data(self, data: np.ndarray) -> None: - self._grid.data[:] = data - self._grid.update_range((0, 0, 0), self._grid.size) + # If dimensions are unchanged, reuse the buffer + if data.shape == self._grid.data.shape: + self._grid.data[:] = data + self._grid.update_range((0, 0, 0), self._grid.size) + # Otherwise, the size (and maybe number of dimensions) changed + # - we need a new buffer + else: + self._grid = pygfx.Texture(data, dim=2) + self._image.geometry = pygfx.Geometry(grid=self._grid) + # RGB images (i.e. 3D datasets) cannot have a colormap + self._material.map = None if self._is_rgb() else self._cmap.to_pygfx() def visible(self) -> bool: return bool(self._image.visible) @@ -95,7 +104,9 @@ def colormap(self) -> _cmap.Colormap: def set_colormap(self, cmap: _cmap.Colormap) -> None: self._cmap = cmap - self._material.map = cmap.to_pygfx() + # RGB (i.e. 3D) images should not have a colormap + if not self._is_rgb(): + self._material.map = cmap.to_pygfx() self._render() def start_move(self, pos: Sequence[float]) -> None: @@ -111,6 +122,9 @@ def remove(self) -> None: def get_cursor(self, mme: MouseMoveEvent) -> CursorType | None: return None + def _is_rgb(self) -> bool: + return self.data().ndim == 3 and isinstance(self._image, pygfx.Image) + class PyGFXRectangle(RectangularROIHandle): def __init__( diff --git a/src/ndv/views/_qt/_array_view.py b/src/ndv/views/_qt/_array_view.py index caa8c839..584e3955 100644 --- a/src/ndv/views/_qt/_array_view.py +++ b/src/ndv/views/_qt/_array_view.py @@ -164,14 +164,14 @@ def __init__(self, parent: QWidget | None = None) -> None: self.histogram_btn = QPushButton(add_histogram_icon, "") self.histogram_btn.setCheckable(True) - top = QHBoxLayout() - top.setSpacing(5) - top.setContentsMargins(0, 0, 0, 0) - top.addWidget(self.visible) - top.addWidget(self.cmap) - top.addWidget(self.clims) - top.addWidget(self.auto_clim) - top.addWidget(self.histogram_btn) + self._lut_layout = QHBoxLayout() + self._lut_layout.setSpacing(5) + self._lut_layout.setContentsMargins(0, 0, 0, 0) + self._lut_layout.addWidget(self.visible) + self._lut_layout.addWidget(self.cmap) + self._lut_layout.addWidget(self.clims) + self._lut_layout.addWidget(self.auto_clim) + self._lut_layout.addWidget(self.histogram_btn) self._histogram: QWidget | None = None @@ -179,7 +179,7 @@ def __init__(self, parent: QWidget | None = None) -> None: self._layout = QVBoxLayout(self) self._layout.setSpacing(0) self._layout.setContentsMargins(0, 0, 0, 0) - self._layout.addLayout(top) + self._layout.addLayout(self._lut_layout) class QLutView(LutView): @@ -253,6 +253,16 @@ def _on_q_histogram_toggled(self, toggled: bool) -> None: self.histogramRequested.emit(self._channel) +class QRGBView(QLutView): + def __init__(self, channel: ChannelKey = None) -> None: + super().__init__(channel) + # Hide the cmap selector + self._qwidget.cmap.setVisible(False) + # Insert a new label + self._label = QLabel("RGB") + self._qwidget._lut_layout.insertWidget(1, self._label) + + class ROIButton(QPushButton): def __init__(self, parent: QWidget | None = None): super().__init__(parent) @@ -390,7 +400,11 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): # not using QEnumComboBox because we want to exclude some values for now self.channel_mode_combo = QComboBox(self) self.channel_mode_combo.addItems( - [ChannelMode.GRAYSCALE.value, ChannelMode.COMPOSITE.value] + [ + ChannelMode.GRAYSCALE.value, + ChannelMode.COMPOSITE.value, + ChannelMode.RGBA.value, + ] ) # button to reset the zoom of the canvas @@ -488,7 +502,7 @@ def __init__( self._visible_axes: Sequence[AxisKey] = [] def add_lut_view(self, channel: ChannelKey) -> QLutView: - view = QLutView(channel) + view = QRGBView(channel) if channel == "RGB" else QLutView(channel) self._luts[channel] = view view.histogramRequested.connect(self.histogramRequested) diff --git a/src/ndv/views/_vispy/_array_canvas.py b/src/ndv/views/_vispy/_array_canvas.py index 5779f08a..964e5a01 100755 --- a/src/ndv/views/_vispy/_array_canvas.py +++ b/src/ndv/views/_vispy/_array_canvas.py @@ -43,7 +43,7 @@ class VispyImageHandle(ImageHandle): def __init__(self, visual: scene.Image | scene.Volume) -> None: self._visual = visual - self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3 + self._allowed_dims = {2, 3} if isinstance(visual, scene.visuals.Image) else {3} def data(self) -> np.ndarray: try: @@ -52,7 +52,7 @@ def data(self) -> np.ndarray: return self._visual._last_data # type: ignore [no-any-return] def set_data(self, data: np.ndarray) -> None: - if not data.ndim == self._ndim: + if data.ndim not in self._allowed_dims: warnings.warn( f"Got wrong number of dimensions ({data.ndim}) for vispy " f"visual of type {type(self._visual)}.", diff --git a/src/ndv/views/_wx/_array_view.py b/src/ndv/views/_wx/_array_view.py index 3ca6fad5..4fe7ab87 100644 --- a/src/ndv/views/_wx/_array_view.py +++ b/src/ndv/views/_wx/_array_view.py @@ -91,16 +91,16 @@ def __init__(self, parent: wx.Window) -> None: _add_icon(self.histogram, "foundation:graph-bar") # Layout - widget_sizer = wx.BoxSizer(wx.HORIZONTAL) - widget_sizer.Add(self.visible, 0, wx.ALL, 2) - widget_sizer.Add(self.cmap, 0, wx.ALL, 2) - widget_sizer.Add(self.clims, 1, wx.ALL, 2) - widget_sizer.Add(self.auto_clim, 0, wx.ALL, 2) - widget_sizer.Add(self.histogram, 0, wx.ALL, 2) - widget_sizer.SetSizeHints(self) + self._widget_sizer = wx.BoxSizer(wx.HORIZONTAL) + self._widget_sizer.Add(self.visible, 0, wx.ALL, 2) + self._widget_sizer.Add(self.cmap, 0, wx.ALL, 2) + self._widget_sizer.Add(self.clims, 1, wx.ALL, 2) + self._widget_sizer.Add(self.auto_clim, 0, wx.ALL, 2) + self._widget_sizer.Add(self.histogram, 0, wx.ALL, 2) + self._widget_sizer.SetSizeHints(self) self.sizer = wx.BoxSizer(wx.VERTICAL) - self.sizer.Add(widget_sizer, 0, wx.EXPAND, 5) + self.sizer.Add(self._widget_sizer, 0, wx.EXPAND, 5) self.SetSizer(self.sizer) self.Layout() @@ -211,6 +211,15 @@ def close(self) -> None: self._wxwidget.Close() +class WxRGBView(WxLutView): + def __init__(self, parent: wx.Window, channel: ChannelKey = None) -> None: + super().__init__(parent, channel) + self._wxwidget.cmap.Hide() + lbl = wx.StaticText(self._wxwidget, label="RGB") + self._wxwidget._widget_sizer.Insert(1, lbl, 0, wx.ALIGN_CENTER_VERTICAL, 5) + self._wxwidget.Layout() + + # mostly copied from _qt.qt_view._QDimsSliders class _WxDimsSliders(wx.Panel): currentIndexChanged = Signal() @@ -309,7 +318,11 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): # Channel mode combo box self.channel_mode_combo = wx.ComboBox( self, - choices=[ChannelMode.GRAYSCALE.value, ChannelMode.COMPOSITE.value], + choices=[ + ChannelMode.GRAYSCALE.value, + ChannelMode.COMPOSITE.value, + ChannelMode.RGBA.value, + ], style=wx.CB_DROPDOWN, ) @@ -419,9 +432,8 @@ def frontend_widget(self) -> wx.Window: return self._wxwidget def add_lut_view(self, channel: ChannelKey) -> WxLutView: - view = WxLutView(self.frontend_widget(), channel) - - # Add the LutView to the Viewer + wdg = self.frontend_widget() + view = WxRGBView(wdg, channel) if channel == "RGB" else WxLutView(wdg, channel) self._wxwidget.luts.Add(view._wxwidget, 0, wx.EXPAND | wx.BOTTOM, 5) self._luts[channel] = view # TODO: Reusable synchronization with ViewerModel diff --git a/src/ndv/views/bases/_array_view.py b/src/ndv/views/bases/_array_view.py index c1a6d204..38ca1e23 100644 --- a/src/ndv/views/bases/_array_view.py +++ b/src/ndv/views/bases/_array_view.py @@ -65,7 +65,7 @@ def hide_sliders( self, axes_to_hide: Container[Hashable], *, show_remainder: bool = ... ) -> None: ... @abstractmethod - def add_lut_view(self, channel: ChannelKey) -> LutView: ... + def add_lut_view(self, key: ChannelKey) -> LutView: ... @abstractmethod def remove_lut_view(self, view: LutView) -> None: ... diff --git a/tests/conftest.py b/tests/conftest.py index b8a7b542..b2611b0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -61,7 +61,7 @@ def any_app(request: pytest.FixtureRequest) -> Iterator[Any]: if frontend == GuiFrontend.QT: app = request.getfixturevalue("qapp") qtbot = request.getfixturevalue("qtbot") - with patch.object(app, "exec", lambda *_: None): + with patch.object(app, "exec", lambda *_: app.processEvents()): with _catch_qt_leaks(request, app): yield app, qtbot elif frontend == GuiFrontend.JUPYTER: diff --git a/tests/test_controller.py b/tests/test_controller.py index 22d7d5e0..fb2e042e 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -391,3 +391,23 @@ def test_roi_interaction() -> None: (canvas_roi_start[1] + canvas_roi_end[1]) / 2, ) assert roi_view.get_cursor(mme) == CursorType.ALL_ARROW + + +@pytest.mark.allow_leaks +@pytest.mark.usefixtures("any_app") +def test_rgb_display_magic() -> None: + # FIXME: Something in the QLutView is causing leaked qt widgets here. + # Doesn't seem to be coming from the QRGBView... + def assert_rgb_magic_works(rgb_data: np.ndarray) -> None: + viewer = ArrayViewer(rgb_data) + assert viewer.display_model.channel_mode == ChannelMode.RGBA + # Note Multiple correct answers here - modulus covers both cases + assert cast("int", viewer.display_model.channel_axis) % rgb_data.ndim == 4 + assert cast("int", viewer.display_model.visible_axes[0]) % rgb_data.ndim == 2 + assert cast("int", viewer.display_model.visible_axes[1]) % rgb_data.ndim == 3 + + rgb_data = np.ones((1, 2, 3, 4, 3), dtype=np.uint8) + assert_rgb_magic_works(rgb_data) + + rgba_data = np.ones((1, 2, 3, 4, 4), dtype=np.uint8) + assert_rgb_magic_works(rgba_data)