diff --git a/example_ndv.py b/example_ndv.py new file mode 100644 index 0000000..b9a70a7 --- /dev/null +++ b/example_ndv.py @@ -0,0 +1,39 @@ +# sanity checks for ndv example + +try: + import ndv +except ImportError: + raise ImportError("You need to `pip install ndv` to run this example.") +from typing import cast + +from qtpy import QT6 + +if not QT6: + raise ImportError("ndv will require QT>=6.4.") + +from careamics.utils import get_careamics_home +from careamics_portfolio import PortfolioManager +from napari.qt import get_stylesheet + +from careamics_napari.training_plugin import TrainPlugin, TrainPluginWrapper +from careamics_napari.widgets import PredictDataWidget + +# create a viewer and your main gui +viewer = ndv.ArrayViewer() +scroll = TrainPluginWrapper(viewer) +# apply the napari stylesheet +scroll.setStyleSheet(get_stylesheet("dark")) +scroll.show() + +# get sample data +files = PortfolioManager().denoising.N2V_SEM.download(path=get_careamics_home()) + +# my gross hack to pre-populate the text fields with the sample files :) +wdg = cast("TrainPlugin", scroll.widget()) +wdg.data_layers[0].train_images_folder.text_field.setText(files[-2]) +wdg.data_layers[0].val_images_folder.text_field.setText(files[-1]) +pred_wdg = wdg.prediction_widget.findChild(PredictDataWidget) +pred_wdg.pred_images_folder.text_field.setText(files[-1]) + +# run the Qt event loop +ndv.run_app() diff --git a/src/careamics_napari/training_plugin.py b/src/careamics_napari/training_plugin.py index 8ce52c0..cccfe87 100644 --- a/src/careamics_napari/training_plugin.py +++ b/src/careamics_napari/training_plugin.py @@ -2,7 +2,7 @@ from pathlib import Path from queue import Queue -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from careamics import CAREamist from careamics.config.support import SupportedAlgorithm @@ -45,6 +45,7 @@ import numpy as np if TYPE_CHECKING: + import ndv import napari # at run time @@ -101,7 +102,7 @@ class TrainPlugin(QWidget): def __init__( self: Self, - napari_viewer: Optional[napari.Viewer] = None, + napari_viewer: Union[napari.Viewer, "ndv.ArrayViewer"] = None, ) -> None: """Initialize the plugin. @@ -345,21 +346,28 @@ def _update_from_prediction(self, update: PredictionUpdate) -> None: ) else: - if update.type == PredictionUpdateType.SAMPLE: + if update.type == PredictionUpdateType.SAMPLE and self.viewer is not None: # add image to napari # TODO keep scaling? - if self.viewer is not None: - # value is eighter a numpy array or a list of numpy arrays with each sample/timepoint as an element - if isinstance(update.value, list): - # combine all samples - samples = np.concatenate(update.value, axis=0) - else: - samples = update.value - - # reshape the prediction to match the input axes - samples = reshape_prediction(samples, self.train_config_signal.axes, self.pred_config_signal.is_3d) - + # value is eighter a numpy array or a list of numpy arrays with each sample/timepoint as an element + if isinstance(update.value, list): + # combine all samples + samples = np.concatenate(update.value, axis=0) + else: + samples = update.value + + # reshape the prediction to match the input axes + samples = reshape_prediction(samples, self.train_config_signal.axes, self.pred_config_signal.is_3d) + + # if isinstance(self.viewer, napari.Viewer): + # elif isinstance(self.viewer, ndv.ArrayViewer): + if type(self.viewer).__module__.startswith("napari"): self.viewer.add_image(samples, name="Prediction") + elif type(self.viewer).__module__.startswith("ndv"): + # one could also instantiate and show new ndv.ArrayViewer here + self.viewer.data = samples + self.viewer.show() + else: self.pred_status.update(update)