diff --git a/src/imagej/_java.py b/src/imagej/_java.py index 1de17a29..e7d8fb25 100644 --- a/src/imagej/_java.py +++ b/src/imagej/_java.py @@ -30,6 +30,10 @@ class MyJavaClasses(JavaClasses): significantly easier and more readable. """ + @JavaClasses.java_import + def Double(self): + return "java.lang.Double" + @JavaClasses.java_import def Throwable(self): return "java.lang.Throwable" @@ -50,6 +54,62 @@ def MetadataWrapper(self): def LabelingIOService(self): return "io.scif.labeling.LabelingIOService" + @JavaClasses.java_import + def ChapmanRichardsAxis(self): + return "net.imagej.axis.ChapmanRichardsAxis" + + @JavaClasses.java_import + def DefaultLinearAxis(self): + return "net.imagej.axis.DefaultLinearAxis" + + @JavaClasses.java_import + def EnumeratedAxis(self): + return "net.imagej.axis.EnumeratedAxis" + + @JavaClasses.java_import + def ExponentialAxis(self): + return "net.imagej.axis.ExponentialAxis" + + @JavaClasses.java_import + def ExponentialRecoveryAxis(self): + return "net.imagej.axis.ExponentialRecoveryAxis" + + @JavaClasses.java_import + def GammaVariateAxis(self): + return "net.imagej.axis.GammaVariateAxis" + + @JavaClasses.java_import + def GaussianAxis(self): + return "net.imagej.axis.GaussianAxis" + + @JavaClasses.java_import + def IdentityAxis(self): + return "net.imagej.axis.IdentityAxis" + + @JavaClasses.java_import + def InverseRodbardAxis(self): + return "net.imagej.axis.InverseRodbardAxis" + + @JavaClasses.java_import + def LogLinearAxis(self): + return "net.imagej.axis.LogLinearAxis" + + @JavaClasses.java_import + def PolynomialAxis(self): + return "net.imagej.axis.PolynomialAxis" + + @JavaClasses.java_import + def PowerAxis(self): + return "net.imagej.axis.PowerAxis" + + @JavaClasses.java_import + def RodbardAxis(self): + return "net.imagej.axis.RodbardAxis" + + @JavaClasses.java_import + def VariableAxis(self): + return "net.imagej.axis.VariableAxis" + @JavaClasses.java_import def Dataset(self): return "net.imagej.Dataset" @@ -106,6 +166,10 @@ def ImgView(self): def ImgLabeling(self): return "net.imglib2.roi.labeling.ImgLabeling" + @JavaClasses.java_import + def IntegerType(self): + return "net.imglib2.type.numeric.IntegerType" + @JavaClasses.java_import def Named(self): return "org.scijava.Named" diff --git a/src/imagej/array.py b/src/imagej/array.py new file mode 100644 index 00000000..7c939e96 --- /dev/null +++ b/src/imagej/array.py @@ -0,0 +1,122 @@ +import numpy as np +import xarray as xr +from scyjava import _convert + +import imagej.dims as dims + + +@xr.register_dataarray_accessor("img") +class ImgAccessor: + def __init__(self, xarr): + self._data = xarr + + @property + def is_rgb(self): + """ + Returns True or False if the xarray.DataArray is an RGB image. + + :return: Boolean + """ + ch_labels = ["c", "ch", "Channel"] + # check if array is signed + if self._data.min() < 0: + return False + # check if array is integer dtype + if not np.issubdtype(self._data.data.dtype, np.integer): + return False + # check bitsperpixel + if self._data.dtype.itemsize * 8 != 8: + return False + # check if "channel" present + if not any(dim in self._data.dims for dim in ch_labels): + return False + # check channel length = 3 exactly + for dim in self._data.dims: + if dim in ch_labels: + loc = self._data.dims.index(dim) + if self._data.shape[loc] != 3: + return False + + return True + + +@xr.register_dataarray_accessor("metadata") +class MetadataAccessor: + def __init__(self, xarr): + self._data = xarr + self._update() + + @property + def axes(self): + """ + Returns a tuple of the ImageJ axes. + + :return: A Python tuple of the ImageJ axes. + """ + return ( + tuple(self._data.attrs["imagej"].get("scifio.metadata.image").get("axes")) + if "scifio.metadata.image" in self._data.attrs["imagej"] + else None + ) + + def set(self, metadata: dict): + """ + Set the metadata of the parent xarray.DataArray. + + :param metadata: A Python dict representing the image metadata. + """ + self._data.attrs["imagej"] = metadata + + def get(self): + """ + Get the metadata dict of the the parent xarray.DataArray. + + :return: A Python dict representing the image metadata. + """ + return self._data.attrs["imagej"] + + def tree(self): + """ + Print a tree of the metadata of the parent xarray.DataArray. + """ + self._print_dict_tree(self._data.attrs["imagej"]) + + def _print_dict_tree(self, dictionary, indent="", prefix=""): + for idx, (key, value) in enumerate(dictionary.items()): + if idx == len(dictionary) - 1: + connector = "└──" + else: + connector = "├──" + print(indent + connector + prefix + " " + str(key)) + if isinstance(value, (dict, _convert.JavaMap)): + if idx == len(dictionary) - 1: + self._print_dict_tree(value, indent + " ", prefix="── ") + else: + self._print_dict_tree(value, indent + "│ ", prefix="── ") + + def _update(self): + if self._data.attrs.get("imagej"): + # update axes + axes = [None] * len(self._data.dims) + for i in range(len(self.axes)): + ax_label = dims._convert_dim(self.axes[i].type().getLabel(), "python") + if ax_label in self._data.dims: + axes[self._data.dims.index(ax_label)] = self.axes[i] + self._data.attrs["imagej"].get("scifio.metadata.image", {})["axes"] = axes + + # update axis lengths + old_ax_len_metadata = ( + self._data.attrs["imagej"] + .get("scifio.metadata.image", {}) + .get("axisLengths", {}) + ) + new_ax_len_metadata = {} + for i in range(len(self.axes)): + ax_type = self.axes[i].type() + if ax_type in old_ax_len_metadata.keys(): + ax_label = dims._convert_dim(ax_type.getLabel(), "python") + curr_ax_len = self._data.shape[self._data.dims.index(ax_label)] + new_ax_len_metadata[ax_type] = curr_ax_len + self._data.attrs["imagej"].get("scifio.metadata.image", {})[ + "axisLengths" + ] = new_ax_len_metadata diff --git a/src/imagej/convert.py b/src/imagej/convert.py index a3bcdb89..1ca46984 100644 --- a/src/imagej/convert.py +++ b/src/imagej/convert.py @@ -13,6 +13,7 @@ from jpype import JByte, JException, JFloat, JLong, JObject, JShort from labeling import Labeling +import imagej.array # noqa:F401 import imagej.dims as dims import imagej.images as images from imagej._java import jc @@ -166,7 +167,10 @@ def xarray_to_dataset(ij: "jc.ImageJ", xarr) -> "jc.Dataset": axes = dims._assign_axes(xarr) dataset.setAxes(axes) dataset.setName(xarr.name) - _assign_dataset_metadata(dataset, xarr.attrs) + if hasattr(xarr, "metadata"): + _assign_dataset_metadata(dataset, xarr.metadata.get()) + else: + _assign_dataset_metadata(dataset, xarr.attrs["imagej"]) return dataset @@ -230,15 +234,18 @@ def java_to_xarray(ij: "jc.ImageJ", jobj) -> xr.DataArray: assert hasattr(permuted_rai, "dim_axes") xr_axes = list(permuted_rai.dim_axes) xr_dims = list(permuted_rai.dims) - xr_attrs = sj.to_python(permuted_rai.getProperties()) - xr_attrs = {sj.to_python(k): sj.to_python(v) for k, v in xr_attrs.items()} # reverse axes and dims to match narr xr_axes.reverse() xr_dims.reverse() xr_dims = dims._convert_dims(xr_dims, direction="python") xr_coords = dims._get_axes_coords(xr_axes, xr_dims, narr.shape) name = jobj.getName() if isinstance(jobj, jc.Named) else None - return xr.DataArray(narr, dims=xr_dims, coords=xr_coords, attrs=xr_attrs, name=name) + xr_attrs = {"imagej": {}} + xarr = xr.DataArray(narr, dims=xr_dims, coords=xr_coords, name=name, attrs=xr_attrs) + # use the MetadataAccessor to add metadata to the xarray + xarr.metadata.set(dict(sj.to_python(permuted_rai.getProperties()))) + xarr.metadata._update() + return xarr def supports_java_to_ndarray(ij: "jc.ImageJ", obj) -> bool: @@ -509,7 +516,7 @@ def metadata_wrapper_to_dict(ij: "jc.ImageJ", metadata_wrapper: "jc.MetadataWrap #################### -def _assign_dataset_metadata(dataset: "jc.Dataset", attrs): +def _assign_dataset_metadata(dataset: "jc.Dataset", attrs: dict): """ :param dataset: ImageJ2 Dataset :param attrs: Dictionary containing metadata diff --git a/src/imagej/dims.py b/src/imagej/dims.py index f009c3a0..eab37e33 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -2,12 +2,12 @@ Utility functions for querying and manipulating dimensional axis metadata. """ import logging -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import scyjava as sj import xarray as xr -from jpype import JException, JObject +from jpype import JObject from imagej._java import jc from imagej.images import is_arraylike as _is_arraylike @@ -177,49 +177,40 @@ def prioritize_rai_axes_order( return permute_order -def _assign_axes(xarr: xr.DataArray): +def _assign_axes( + xarr: xr.DataArray, +) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]: """ - Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; - currently supports EnumeratedAxis - :param xarr: xarray that holds the units - :return: A list of ImageJ Axis with the specified origin and scale + Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both + DefaultLinearAxis and the newer EnumeratedAxis. + :param xarr: xarray that holds the data. + :return: A list of ImageJ Axis with the specified origin and scale. """ - Double = sj.jimport("java.lang.Double") - - axes = [""] * len(xarr.dims) - - # try to get EnumeratedAxis, if not then default to LinearAxis in the loop - try: - EnumeratedAxis = _get_enumerated_axis() - except (JException, TypeError): - EnumeratedAxis = None - - for dim in xarr.dims: - axis_str = _convert_dim(dim, direction="java") + axes = [""] * xarr.ndim + for i in range(xarr.ndim): + dim = xarr.dims[i] + axis_str = _convert_dim(dim, "java") ax_type = jc.Axes.get(axis_str) ax_num = _get_axis_num(xarr, dim) - scale = _get_scale(xarr.coords[dim]) + coords_arr = xarr.coords[dim].to_numpy() - if scale is None: + # check if coords/scale is numeric + if _is_numeric_scale(coords_arr): + doub_coords = [jc.Double(np.double(x)) for x in xarr.coords[dim]] + else: _logger.warning( f"The {ax_type.label} axis is non-numeric and is translated " "to a linear index." ) doub_coords = [ - Double(np.double(x)) for x in np.arange(len(xarr.coords[dim])) + jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim])) ] - else: - doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]] - # EnumeratedAxis is a new axis made for xarray, so is only present in - # ImageJ versions that are released later than March 2020. - # This actually returns a LinearAxis if using an earlier version. - if EnumeratedAxis is not None: - java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords)) + # use the xarr metadata if available to assign axes + if hasattr(xarr, "metadata") and xarr.metadata.axes: + axes[ax_num] = xarr.metadata.axes[i] else: - java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) - - axes[ax_num] = java_axis + axes[ax_num] = _get_fallback_linear_axis(ax_type, doub_coords) return axes @@ -295,27 +286,28 @@ def _get_scale(axis): return None -def _get_enumerated_axis(): - """Get EnumeratedAxis. - - EnumeratedAxis is only in releases later than March 2020. If using - an older version of ImageJ without EnumeratedAxis, use - _get_linear_axis() instead. +def _is_numeric_scale(coords_array: np.ndarray) -> bool: """ - return sj.jimport("net.imagej.axis.EnumeratedAxis") + Checks if the coordinates array of the given axis is numeric. + :param coords_array: A 1D NumPy array. + :return: bool + """ + return np.issubdtype(coords_array.dtype, np.number) -def _get_linear_axis(axis_type: "jc.AxisType", values): - """Get linear axis. - This is used if no EnumeratedAxis is found. If EnumeratedAxis - is available, use _get_enumerated_axis() instead. +def _get_fallback_linear_axis(axis_type: "jc.AxisType", values): + """ + Get a DefaultLinearAxis manually when all other axes + resources are unavailable. """ - DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis") origin = values[0] - scale = values[1] - values[0] - axis = DefaultLinearAxis(axis_type, scale, origin) - return axis + # calculate the slope using the values/coord array + if len(values) <= 1: + scale = 1 + else: + scale = values[1] - values[0] + return jc.DefaultLinearAxis(axis_type, scale, origin) def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus": diff --git a/tests/test_image_conversion.py b/tests/test_image_conversion.py index 977ce47f..562a6d31 100644 --- a/tests/test_image_conversion.py +++ b/tests/test_image_conversion.py @@ -115,7 +115,10 @@ def assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr): assert list(xarr.dims) == list(invert_xarr.dims) for key in xarr.coords: assert (xarr.coords[key] == invert_xarr.coords[key]).all() - assert xarr.attrs == invert_xarr.attrs + if "Hello" in xarr.attrs.keys(): + assert xarr.attrs["Hello"] == invert_xarr.attrs["Hello"] + else: + assert xarr.attrs == invert_xarr.attrs assert xarr.name == invert_xarr.name