diff --git a/.travis/build.sh b/.travis/build.sh index 5a6fd77d..421037c1 100755 --- a/.travis/build.sh +++ b/.travis/build.sh @@ -16,9 +16,9 @@ check () { } # -- create a test enviroment -- -conda env create -q -f environment.yml +conda create -n imagej -y python=$TRAVIS_PYTHON_VERSION source activate imagej -conda install -q -y python=$TRAVIS_PYTHON_VERSION +conda env update -f environment.yml # -- ensure supporting tools are available -- check curl git unzip @@ -58,6 +58,9 @@ ij_dir=$HOME/Fiji.app echo "ij_dir = $ij_dir" python setup.py install +# -- unset JAVA_HOME in case it was set -- +unset JAVA_HOME + # -- run tests with local Fiji.app -- python -O test/test_imagej.py --ij "$ij_dir" diff --git a/README.md b/README.md index 4eb3d8bc..e2974242 100644 --- a/README.md +++ b/README.md @@ -11,43 +11,65 @@ see "ImageJ Server" below for details. 1. Install [Conda](https://conda.io/): * On Windows, install Conda using [Chocolatey](https://chocolatey.org): `choco install miniconda3` - * On macOS, install Conda using [Homebrew](https://brew.sh): `brew install miniconda` + * On macOS, install Conda using [Homebrew](https://brew.sh): `brew cask install miniconda` * On Linux, install Conda using its [RPM or Debian package](https://www.anaconda.com/rpm-and-debian-repositories-for-miniconda/), or [with the Miniconda install script](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html). -2. [Activate the conda-forge channel](https://conda-forge.org/docs/user/introduction.html#how-can-i-install-packages-from-conda-forge): +2. Configure your shell for use with conda: + ``` + conda init bash + ``` + Where `bash` is the shell you use. + Then start a new shell instance. + +3. [Activate the conda-forge channel](https://conda-forge.org/docs/user/introduction.html#how-can-i-install-packages-from-conda-forge): ``` conda config --add channels conda-forge conda config --set channel_priority strict ``` -3. Install pyimagej into a new conda environment: +4. Install pyimagej into a new conda environment: ``` conda create -n pyimagej pyimagej openjdk=8 ``` -4. Whenever you want to use pyimagej, activate its environment: +5. Whenever you want to use pyimagej, activate its environment: ``` conda activate pyimagej ``` ### Installation asides -* If you want to use [scikit-image](https://scikit-image.org/) in conjunction, as demonstrated below, you can install it also via: - +* If you want to use [scikit-image](https://scikit-image.org/) in conjunction, + as demonstrated below, you can install it also via: ``` conda install scikit-image ``` -* The above command installs pyimagej with OpenJDK 8; if you leave off the `openjdk=8` it will install OpenJDK 11 by default, which should also work, but is less well tested and may have more rough edges. +* The above command installs pyimagej with OpenJDK 8; if you leave off the + `openjdk=8` it will install OpenJDK 11 by default, which should also work, but + is less well tested and may have more rough edges. * It is possible to dynamically install pyimagej from within a Jupyter notebook: - ``` import sys !conda install --yes --prefix {sys.prefix} -c conda-forge pyimagej openjdk=8 ``` - -* If you would prefer to install pyimagej via pip, more legwork is required. See [this thread](https://forum.image.sc/t/how-do-i-install-pyimagej/23189/4) for hints. + This approach is useful for [JupyterHub](https://jupyter.org/hub) on the + cloud, e.g. [Binder](https://mybinder.org/), to utilize pyimagej in select + notebooks without advance installation. This reduces time needed to create + and launch the environment, at the expense of a longer startup time the first + time a pyimagej-enabled notebook is run. See [this itkwidgets example + notebook](https://github.com/InsightSoftwareConsortium/itkwidgets/blob/v0.24.2/examples/ImageJImgLib2.ipynb) + for an example. + +* It is possible to dynamically install pyimagej on + [Google Colab](https://colab.research.google.com/). See + [this thread](https://forum.image.sc/t/pyimagej-on-google-colab/32804) for + guidance. A major advantage of Google Colab is free GPU in the cloud. + +* If you would prefer to install pyimagej via pip, more legwork is required. + See [this thread](https://forum.image.sc/t/how-do-i-install-pyimagej/23189/4) + for hints. ## Usage diff --git a/environment.yml b/environment.yml index 90013d8e..a12deed6 100644 --- a/environment.yml +++ b/environment.yml @@ -2,11 +2,12 @@ name: imagej channels: - conda-forge dependencies: - - imglyb + - imglyb=0.3.5 - matplotlib - numpy - openjdk=8 - - pyjnius + - pyjnius=1.2.0 - scyjava - pillow # for server - requests # for server + - xarray diff --git a/imagej/imagej.py b/imagej/imagej.py index 3bacf762..202e8fa9 100644 --- a/imagej/imagej.py +++ b/imagej/imagej.py @@ -13,6 +13,8 @@ import jnius_config from pathlib import Path import numpy +import xarray as xr +import warnings _logger = logging.getLogger(__name__) @@ -124,6 +126,8 @@ def init(ij_dir_or_version_or_endpoint=None, headless=True, new_instance=False): # Must import imglyb (not scyjava) to spin up the JVM now. import imglyb from jnius import autoclass + from jnius import cast + import scyjava # Initialize ImageJ. ImageJ = autoclass('net.imagej.ImageJ') @@ -134,7 +138,11 @@ def init(ij_dir_or_version_or_endpoint=None, headless=True, new_instance=False): from scyjava import jclass, isjava, to_java, to_python Dataset = autoclass('net.imagej.Dataset') + ImgPlus = autoclass('net.imagej.ImgPlus') + Img = autoclass('net.imglib2.img.Img') RandomAccessibleInterval = autoclass('net.imglib2.RandomAccessibleInterval') + Axes = autoclass('net.imagej.axis.Axes') + DefaultLinearAxis = autoclass('net.imagej.axis.DefaultLinearAxis') class ImageJPython: def __init__(self, ij): @@ -286,19 +294,159 @@ def run_script(self, language, script, args=None): def to_java(self, data): """ - Converts the data into a java equivalent. For numpy arrays, the java image points to the python array + Converts the data into a java equivalent. For numpy arrays, the java image points to the python array. + + In addition to the scyjava types, we allow ndarray-like and xarray-like variables """ if self._is_memoryarraylike(data): return imglyb.to_imglib(data) + if self._is_xarraylike(data): + return self.to_dataset(data) return to_java(data) def to_dataset(self, data): + """Converts the data into an ImageJ dataset""" + if self._is_xarraylike(data): + return self._xarray_to_dataset(data) + if self._is_arraylike(data): + return self._numpy_to_dataset(data) + if scyjava.isjava(data): + return self._java_to_dataset(data) + + raise TypeError(f'Type not supported: {type(data)}') + + def _numpy_to_dataset(self, data): + rai = imglyb.to_imglib(data) + return self._java_to_dataset(rai) + + def _xarray_to_dataset(self, xarr): + """ + Converts a xarray dataarray with specified dim order to an image + :param xarr: Pass an xarray dataarray and turn into a dataset. + :return: The dataset + """ + dataset = self._numpy_to_dataset(xarr.values) + axes = self._assign_axes(xarr) + dataset.setAxes(axes) + + # Currently, we have no handling for nonlinear axes, but I thought it should warn instead of fail. + if not self._axis_is_linear(xarr.coords): + warnings.warn("Not all axes are linear. The nonlinear axes are not mapped correctly.") + + self._assign_dataset_metadata(dataset, xarr.attrs) + + return dataset + + def _assign_axes(self, xarr): + """ + Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; currently supports DefaultLinearAxis. + :param xarr: xarray that holds the units + :return: A list of ImageJ Axis with the specified origin and scale + """ + axes = ['']*len(xarr.dims) + + for axis in xarr.dims: + origin = self._get_origin(xarr.coords[axis]) + scale = self._get_scale(xarr.coords[axis]) + + axisStr = self._pydim_to_ijdim(axis) + + ax_type = Axes.get(axisStr) + ax_num = self._get_axis_num(xarr, axis) + if scale is None: + java_axis = DefaultLinearAxis(ax_type) + else: + java_axis = DefaultLinearAxis(ax_type, numpy.double(scale), numpy.double(origin)) + + axes[ax_num] = java_axis + + return axes + + def _pydim_to_ijdim(self, axis): + """Convert between the lowercase Python convention (x, y, z, c, t) to IJ (X, Y, Z, C, T)""" + if str(axis) in ['x', 'y', 'z', 'c', 't']: + return str(axis).upper() + return str(axis) + + def _ijdim_to_pydim(self, axis): + """Convert the IJ uppercase dimension convention (X, Y, Z< C, T) to lowercase python (x, y, z, c, t) """ + if str(axis) in ['X', 'Y', 'Z', 'C', 'T']: + return str(axis).lower() + return str(axis) + + def _get_axis_num(self, xarr, axis): + """ + Get the xarray -> java axis number due to inverted axis order for C style numpy arrays (default) + :param xarr: Xarray to convert + :param axis: Axis number to convert + :return: Axis idx in java + """ + py_axnum = xarr.get_axis_num(axis) + if numpy.isfortran(xarr.values): + return py_axnum + + return len(xarr.dims) - py_axnum - 1 + + + def _assign_dataset_metadata(self, dataset, attrs): + """ + :param dataset: ImageJ Java dataset + :param attrs: Dictionary containing metadata + """ + dataset.getProperties().putAll(self.to_java(attrs)) + + def _axis_is_linear(self, coords): + """ + Check if each axis has linear steps between grid points. Skip over axes with non-numeric entries + :param coords: Xarray coords variable, which is a dict with axis: [axis values] + :return: Whether all axes are linear, or not. + """ + linear = True + for coord, values in coords.items(): + try: + diff = numpy.diff(coords) + if len(numpy.unique(diff)) > 1: + warnings.warn(f'Axis {coord} is not linear') + linear = False + except TypeError: + continue + return linear + + def _get_origin(self, axis): + """ + Get the coordinate origin of an axis, assuming it is the first entry. + :param axis: A 1D list like entry accessible with indexing, which contains the axis coordinates + :return: The origin for this axis. + """ + return axis.values[0] + + def _get_scale(self, axis): + """ + Get the scale of an axis, assuming it is linear and so the scale is simply second - first coordinate. + :param axis: A 1D list like entry accessible with indexing, which contains the axis coordinates + :return: The scale for this axis or None if it is a non-numeric scale. + """ + try: + return axis.values[1] - axis.values[0] + except TypeError: + return None + + def _java_to_dataset(self, data): """ Converts the data into a ImageJ Dataset """ try: if self._ij.convert().supports(data, Dataset): return self._ij.convert().convert(data, Dataset) + if self._ij.convert().supports(data, ImgPlus): + imgPlus = self._ij.convert().convert(data, ImgPlus) + return self._ij.dataset().create(imgPlus) + if self._ij.convert().supports(data, Img): + img = self._ij.convert().convert(data, Img) + return self._ij.dataset().create(ImgPlus(img)) + if self._ij.convert().supports(data, RandomAccessibleInterval): + rai = self._ij.convert().convert(data, RandomAccessibleInterval) + return self._ij.dataset().create(rai) except Exception as exc: _dump_exception(exc) raise exc @@ -308,11 +456,14 @@ def from_java(self, data): """ Converts the data into a python equivalent """ + # todo: convert a datset to xarray + if not isjava(data): return data try: if self._ij.convert().supports(data, Dataset): # HACK: Converter exists for ImagePlus -> Dataset, but not ImagePlus -> RAI. data = self._ij.convert().convert(data, Dataset) + return self._dataset_to_xarray(data) if (self._ij.convert().supports(data, RandomAccessibleInterval)): rai = self._ij.convert().convert(data, RandomAccessibleInterval) return self.rai_to_numpy(rai) @@ -321,6 +472,37 @@ def from_java(self, data): raise exc return to_python(data) + def _dataset_to_xarray(self, dataset): + """ + Converts an ImageJ dataset into an xarray + :param dataset: ImageJ dataset + :return: xarray with reversed (C-style) dims and coords as labeled by the dataset + """ + attrs = self._ij.py.from_java(dataset.getProperties()) + axes = [(cast('net.imagej.axis.DefaultLinearAxis', dataset.axis(idx))) + for idx in range(dataset.numDimensions())] + + dims = [self._ijdim_to_pydim(axes[idx].type().getLabel()) for idx in range(len(axes))] + values = self.rai_to_numpy(dataset) + coords = self._get_axes_coords(axes, dims, numpy.shape(numpy.transpose(values))) + + xarr = xr.DataArray(values, dims=list(reversed(dims)), coords=coords, attrs=attrs) + return xarr + + def _get_axes_coords(self, axes, dims, shape): + """ + Get xarray style coordinate list dictionary from a dataset + :param axes: List of ImageJ axes + :param dims: List of axes labels for each dataset axis + :param shape: F-style, or reversed C-style, shape of axes numpy array. + :return: Dictionary of coordinates for each axis. + """ + coords = {dims[idx]: numpy.arange(axes[idx].origin(), shape[idx]*axes[idx].scale() + axes[idx].origin(), + axes[idx].scale()) + for idx in range(len(dims))} + return coords + + def show(self, image, cmap=None): """ Display a java or python 2D image. @@ -350,6 +532,12 @@ def _is_memoryarraylike(self, arr): hasattr(arr, 'data') and \ type(arr.data).__name__ == 'memoryview' + def _is_xarraylike(self, xarr): + return hasattr(xarr, 'values') and \ + hasattr(xarr, 'dims') and \ + hasattr(xarr, 'coords') and \ + self._is_arraylike(xarr.values) + def _assemble_plugin_macro(self, plugin: str, args=None, ij1_style=True): """ Assemble an ImageJ macro string given a plugin to run and optional arguments in a dict diff --git a/test/test_imagej.py b/test/test_imagej.py index 849a3f5c..fe6ddfd6 100644 --- a/test/test_imagej.py +++ b/test/test_imagej.py @@ -16,9 +16,9 @@ ij = imagej.init(ij_dir) -from jnius import autoclass +from jnius import autoclass, cast import numpy as np - +import xarray as xr class TestImageJ(unittest.TestCase): @@ -107,6 +107,62 @@ def main(self): unittest.main() +class TestXarrayConversion(unittest.TestCase): + def testCstyleArrayWithLabeledDimsConverts(self): + xarr = xr.DataArray(np.random.rand(5, 4, 3, 6, 12), dims=['T', 'Z', 'C', 'Y', 'X'], + coords={'X': range(0, 12), 'Y': np.arange(0, 12, 2), 'C': ['R', 'G', 'B'], + 'Z': np.arange(10, 50, 10), 'T': np.arange(0, 0.05, 0.01)}, + attrs={'Hello': 'Wrld'}) + + dataset = ij.py.to_java(xarr) + axes = [cast('net.imagej.axis.DefaultLinearAxis', dataset.axis(axnum)) for axnum in range(5)] + labels = [axis.type().getLabel() for axis in axes] + origins = [axis.origin() for axis in axes] + scales = [axis.scale() for axis in axes] + + self.assertListEqual(origins, [0, 0, 0, 10, 0]) + self.assertListEqual(scales, [1, 2, 1, 10, 0.01]) + + self.assertListEqual(list(reversed(xarr.dims)), labels) + + self.assertEqual(xarr.attrs, ij.py.from_java(dataset.getProperties())) + + def testFstyleArrayWiathLabeledDimsConverts(self): + xarr = xr.DataArray(np.ndarray([5, 4, 3, 6, 12], order='F'), dims=['t', 'z', 'c', 'y', 'x'], + coords={'x': range(0, 12), 'y': np.arange(0, 12, 2), + 'z': np.arange(10, 50, 10), 't': np.arange(0, 0.05, 0.01)}, + attrs={'Hello': 'Wrld'}) + + dataset = ij.py.to_java(xarr) + axes = [cast('net.imagej.axis.DefaultLinearAxis', dataset.axis(axnum)) for axnum in range(5)] + labels = [axis.type().getLabel() for axis in axes] + origins = [axis.origin() for axis in axes] + scales = [axis.scale() for axis in axes] + + self.assertListEqual(origins, [0, 10, 0, 0, 0]) + self.assertListEqual(scales, [0.01, 10, 1, 2, 1]) + + self.assertListEqual([dim.upper() for dim in xarr.dims], labels) + self.assertEqual(xarr.attrs, ij.py.from_java(dataset.getProperties())) + + def testDatasetConvertsToXarray(self): + xarr = xr.DataArray(np.random.rand(5, 4, 3, 6, 12), dims=['t', 'z', 'c', 'y', 'x'], + coords={'x': list(range(0, 12)), 'y': list(np.arange(0, 12, 2)), 'c': [0, 1, 2], + 'z': list(np.arange(10, 50, 10)), 't': list(np.arange(0, 0.05, 0.01))}, + attrs={'Hello': 'Wrld'}) + + dataset = ij.py.to_java(xarr) + + invert_xarr = ij.py.from_java(dataset) + self.assertTrue((xarr.values == invert_xarr.values).all()) + + self.assertEqual(list(xarr.dims), list(invert_xarr.dims)) + for key in xarr.coords: + self.assertTrue((xarr.coords[key] == invert_xarr.coords[key]).all()) + self.assertEqual(xarr.attrs, invert_xarr.attrs) + + + if __name__ == '__main__': unittest.main()