Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions example_ndv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# sanity checks for ndv example

try:
import ndv
except ImportError:
raise ImportError("You need to `pip install ndv` to run this example.")
from typing import cast

from qtpy import QT6

if not QT6:
raise ImportError("ndv will require QT>=6.4.")

from careamics.utils import get_careamics_home
from careamics_portfolio import PortfolioManager
from napari.qt import get_stylesheet

from careamics_napari.training_plugin import TrainPlugin, TrainPluginWrapper
from careamics_napari.widgets import PredictDataWidget

# create a viewer and your main gui
viewer = ndv.ArrayViewer()
scroll = TrainPluginWrapper(viewer)
# apply the napari stylesheet
scroll.setStyleSheet(get_stylesheet("dark"))
scroll.show()

# get sample data
files = PortfolioManager().denoising.N2V_SEM.download(path=get_careamics_home())

# my gross hack to pre-populate the text fields with the sample files :)
wdg = cast("TrainPlugin", scroll.widget())
wdg.data_layers[0].train_images_folder.text_field.setText(files[-2])
wdg.data_layers[0].val_images_folder.text_field.setText(files[-1])
pred_wdg = wdg.prediction_widget.findChild(PredictDataWidget)
pred_wdg.pred_images_folder.text_field.setText(files[-1])

# run the Qt event loop
ndv.run_app()
36 changes: 22 additions & 14 deletions src/careamics_napari/training_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pathlib import Path
from queue import Queue
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from careamics import CAREamist
from careamics.config.support import SupportedAlgorithm
Expand Down Expand Up @@ -45,6 +45,7 @@
import numpy as np

if TYPE_CHECKING:
import ndv
import napari

# at run time
Expand Down Expand Up @@ -101,7 +102,7 @@ class TrainPlugin(QWidget):

def __init__(
self: Self,
napari_viewer: Optional[napari.Viewer] = None,
napari_viewer: Union[napari.Viewer, "ndv.ArrayViewer"] = None,
) -> None:
"""Initialize the plugin.

Expand Down Expand Up @@ -345,21 +346,28 @@ def _update_from_prediction(self, update: PredictionUpdate) -> None:
)

else:
if update.type == PredictionUpdateType.SAMPLE:
if update.type == PredictionUpdateType.SAMPLE and self.viewer is not None:
# add image to napari
# TODO keep scaling?
if self.viewer is not None:
# value is eighter a numpy array or a list of numpy arrays with each sample/timepoint as an element
if isinstance(update.value, list):
# combine all samples
samples = np.concatenate(update.value, axis=0)
else:
samples = update.value

# reshape the prediction to match the input axes
samples = reshape_prediction(samples, self.train_config_signal.axes, self.pred_config_signal.is_3d)

# value is eighter a numpy array or a list of numpy arrays with each sample/timepoint as an element
if isinstance(update.value, list):
# combine all samples
samples = np.concatenate(update.value, axis=0)
else:
samples = update.value

# reshape the prediction to match the input axes
samples = reshape_prediction(samples, self.train_config_signal.axes, self.pred_config_signal.is_3d)

# if isinstance(self.viewer, napari.Viewer):
# elif isinstance(self.viewer, ndv.ArrayViewer):
if type(self.viewer).__module__.startswith("napari"):
self.viewer.add_image(samples, name="Prediction")
elif type(self.viewer).__module__.startswith("ndv"):
# one could also instantiate and show new ndv.ArrayViewer here
self.viewer.data = samples
self.viewer.show()

else:
self.pred_status.update(update)

Expand Down