Skip to content

Commit a51a81f

Browse files
mpinkertctrueden
authored andcommitted
Implement xarray <-> dataset
This new convention supports bi-directional conversion between xarrays and datasets, under the assumption that an xarray is an image. Currently, this does flip the axes order for C-style (default) indexed xarrays, as Java uses F-style indexing. This behavior conforms to the current status for regular numpy arrays. Right now the conversion also assumes that the axes are linear, such that the axes can be defined with just an origin and a scale (aX + b). Any non-numeric axes labels are currently lost (e.g., if you have coords of [R, G, B] they become [0, 1, 2] upon conversion)
1 parent 7513560 commit a51a81f

File tree

3 files changed

+248
-3
lines changed

3 files changed

+248
-3
lines changed

Diff for: environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ dependencies:
1010
- scyjava
1111
- pillow # for server
1212
- requests # for server
13+
- xarray

Diff for: imagej/imagej.py

+189-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import jnius_config
1414
from pathlib import Path
1515
import numpy
16+
import xarray as xr
17+
import warnings
1618

1719
_logger = logging.getLogger(__name__)
1820

@@ -124,6 +126,8 @@ def init(ij_dir_or_version_or_endpoint=None, headless=True, new_instance=False):
124126
# Must import imglyb (not scyjava) to spin up the JVM now.
125127
import imglyb
126128
from jnius import autoclass
129+
from jnius import cast
130+
import scyjava
127131

128132
# Initialize ImageJ.
129133
ImageJ = autoclass('net.imagej.ImageJ')
@@ -134,7 +138,11 @@ def init(ij_dir_or_version_or_endpoint=None, headless=True, new_instance=False):
134138
from scyjava import jclass, isjava, to_java, to_python
135139

136140
Dataset = autoclass('net.imagej.Dataset')
141+
ImgPlus = autoclass('net.imagej.ImgPlus')
142+
Img = autoclass('net.imglib2.img.Img')
137143
RandomAccessibleInterval = autoclass('net.imglib2.RandomAccessibleInterval')
144+
Axes = autoclass('net.imagej.axis.Axes')
145+
DefaultLinearAxis = autoclass('net.imagej.axis.DefaultLinearAxis')
138146

139147
class ImageJPython:
140148
def __init__(self, ij):
@@ -286,19 +294,159 @@ def run_script(self, language, script, args=None):
286294

287295
def to_java(self, data):
288296
"""
289-
Converts the data into a java equivalent. For numpy arrays, the java image points to the python array
297+
Converts the data into a java equivalent. For numpy arrays, the java image points to the python array.
298+
299+
In addition to the scyjava types, we allow ndarray-like and xarray-like variables
290300
"""
291301
if self._is_memoryarraylike(data):
292302
return imglyb.to_imglib(data)
303+
if self._is_xarraylike(data):
304+
return self.to_dataset(data)
293305
return to_java(data)
294306

295307
def to_dataset(self, data):
308+
"""Converts the data into an ImageJ dataset"""
309+
if self._is_xarraylike(data):
310+
return self._xarray_to_dataset(data)
311+
if self._is_arraylike(data):
312+
return self._numpy_to_dataset(data)
313+
if scyjava.isjava(data):
314+
return self._java_to_dataset(data)
315+
316+
raise TypeError(f'Type not supported: {type(data)}')
317+
318+
def _numpy_to_dataset(self, data):
319+
rai = imglyb.to_imglib(data)
320+
return self._java_to_dataset(rai)
321+
322+
def _xarray_to_dataset(self, xarr):
323+
"""
324+
Converts a xarray dataarray with specified dim order to an image
325+
:param xarr: Pass an xarray dataarray and turn into a dataset.
326+
:return: The dataset
327+
"""
328+
dataset = self._numpy_to_dataset(xarr.values)
329+
axes = self._assign_axes(xarr)
330+
dataset.setAxes(axes)
331+
332+
# Currently, we have no handling for nonlinear axes, but I thought it should warn instead of fail.
333+
if not self._axis_is_linear(xarr.coords):
334+
warnings.warn("Not all axes are linear. The nonlinear axes are not mapped correctly.")
335+
336+
self._assign_dataset_metadata(dataset, xarr.attrs)
337+
338+
return dataset
339+
340+
def _assign_axes(self, xarr):
341+
"""
342+
Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; currently supports DefaultLinearAxis.
343+
:param xarr: xarray that holds the units
344+
:return: A list of ImageJ Axis with the specified origin and scale
345+
"""
346+
axes = ['']*len(xarr.dims)
347+
348+
for axis in xarr.dims:
349+
origin = self._get_origin(xarr.coords[axis])
350+
scale = self._get_scale(xarr.coords[axis])
351+
352+
axisStr = self._pydim_to_ijdim(axis)
353+
354+
ax_type = Axes.get(axisStr)
355+
ax_num = self._get_axis_num(xarr, axis)
356+
if scale is None:
357+
java_axis = DefaultLinearAxis(ax_type)
358+
else:
359+
java_axis = DefaultLinearAxis(ax_type, numpy.double(scale), numpy.double(origin))
360+
361+
axes[ax_num] = java_axis
362+
363+
return axes
364+
365+
def _pydim_to_ijdim(self, axis):
366+
"""Convert between the lowercase Python convention (x, y, z, c, t) to IJ (X, Y, Z, C, T)"""
367+
if str(axis) in ['x', 'y', 'z', 'c', 't']:
368+
return str(axis).upper()
369+
return str(axis)
370+
371+
def _ijdim_to_pydim(self, axis):
372+
"""Convert the IJ uppercase dimension convention (X, Y, Z< C, T) to lowercase python (x, y, z, c, t) """
373+
if str(axis) in ['X', 'Y', 'Z', 'C', 'T']:
374+
return str(axis).lower()
375+
return str(axis)
376+
377+
def _get_axis_num(self, xarr, axis):
378+
"""
379+
Get the xarray -> java axis number due to inverted axis order for C style numpy arrays (default)
380+
:param xarr: Xarray to convert
381+
:param axis: Axis number to convert
382+
:return: Axis idx in java
383+
"""
384+
py_axnum = xarr.get_axis_num(axis)
385+
if numpy.isfortran(xarr.values):
386+
return py_axnum
387+
388+
return len(xarr.dims) - py_axnum - 1
389+
390+
391+
def _assign_dataset_metadata(self, dataset, attrs):
392+
"""
393+
:param dataset: ImageJ Java dataset
394+
:param attrs: Dictionary containing metadata
395+
"""
396+
dataset.getProperties().putAll(self.to_java(attrs))
397+
398+
def _axis_is_linear(self, coords):
399+
"""
400+
Check if each axis has linear steps between grid points. Skip over axes with non-numeric entries
401+
:param coords: Xarray coords variable, which is a dict with axis: [axis values]
402+
:return: Whether all axes are linear, or not.
403+
"""
404+
linear = True
405+
for coord, values in coords.items():
406+
try:
407+
diff = numpy.diff(coords)
408+
if len(numpy.unique(diff)) > 1:
409+
warnings.warn(f'Axis {coord} is not linear')
410+
linear = False
411+
except TypeError:
412+
continue
413+
return linear
414+
415+
def _get_origin(self, axis):
416+
"""
417+
Get the coordinate origin of an axis, assuming it is the first entry.
418+
:param axis: A 1D list like entry accessible with indexing, which contains the axis coordinates
419+
:return: The origin for this axis.
420+
"""
421+
return axis.values[0]
422+
423+
def _get_scale(self, axis):
424+
"""
425+
Get the scale of an axis, assuming it is linear and so the scale is simply second - first coordinate.
426+
:param axis: A 1D list like entry accessible with indexing, which contains the axis coordinates
427+
:return: The scale for this axis or None if it is a non-numeric scale.
428+
"""
429+
try:
430+
return axis.values[1] - axis.values[0]
431+
except TypeError:
432+
return None
433+
434+
def _java_to_dataset(self, data):
296435
"""
297436
Converts the data into a ImageJ Dataset
298437
"""
299438
try:
300439
if self._ij.convert().supports(data, Dataset):
301440
return self._ij.convert().convert(data, Dataset)
441+
if self._ij.convert().supports(data, ImgPlus):
442+
imgPlus = self._ij.convert().convert(data, ImgPlus)
443+
return self._ij.dataset().create(imgPlus)
444+
if self._ij.convert().supports(data, Img):
445+
img = self._ij.convert().convert(data, Img)
446+
return self._ij.dataset().create(ImgPlus(img))
447+
if self._ij.convert().supports(data, RandomAccessibleInterval):
448+
rai = self._ij.convert().convert(data, RandomAccessibleInterval)
449+
return self._ij.dataset().create(rai)
302450
except Exception as exc:
303451
_dump_exception(exc)
304452
raise exc
@@ -308,11 +456,14 @@ def from_java(self, data):
308456
"""
309457
Converts the data into a python equivalent
310458
"""
459+
# todo: convert a datset to xarray
460+
311461
if not isjava(data): return data
312462
try:
313463
if self._ij.convert().supports(data, Dataset):
314464
# HACK: Converter exists for ImagePlus -> Dataset, but not ImagePlus -> RAI.
315465
data = self._ij.convert().convert(data, Dataset)
466+
return self._dataset_to_xarray(data)
316467
if (self._ij.convert().supports(data, RandomAccessibleInterval)):
317468
rai = self._ij.convert().convert(data, RandomAccessibleInterval)
318469
return self.rai_to_numpy(rai)
@@ -321,6 +472,37 @@ def from_java(self, data):
321472
raise exc
322473
return to_python(data)
323474

475+
def _dataset_to_xarray(self, dataset):
476+
"""
477+
Converts an ImageJ dataset into an xarray
478+
:param dataset: ImageJ dataset
479+
:return: xarray with reversed (C-style) dims and coords as labeled by the dataset
480+
"""
481+
attrs = self._ij.py.from_java(dataset.getProperties())
482+
axes = [(cast('net.imagej.axis.DefaultLinearAxis', dataset.axis(idx)))
483+
for idx in range(dataset.numDimensions())]
484+
485+
dims = [self._ijdim_to_pydim(axes[idx].type().getLabel()) for idx in range(len(axes))]
486+
values = self.rai_to_numpy(dataset)
487+
coords = self._get_axes_coords(axes, dims, numpy.shape(numpy.transpose(values)))
488+
489+
xarr = xr.DataArray(values, dims=list(reversed(dims)), coords=coords, attrs=attrs)
490+
return xarr
491+
492+
def _get_axes_coords(self, axes, dims, shape):
493+
"""
494+
Get xarray style coordinate list dictionary from a dataset
495+
:param axes: List of ImageJ axes
496+
:param dims: List of axes labels for each dataset axis
497+
:param shape: F-style, or reversed C-style, shape of axes numpy array.
498+
:return: Dictionary of coordinates for each axis.
499+
"""
500+
coords = {dims[idx]: numpy.arange(axes[idx].origin(), shape[idx]*axes[idx].scale() + axes[idx].origin(),
501+
axes[idx].scale())
502+
for idx in range(len(dims))}
503+
return coords
504+
505+
324506
def show(self, image, cmap=None):
325507
"""
326508
Display a java or python 2D image.
@@ -350,6 +532,12 @@ def _is_memoryarraylike(self, arr):
350532
hasattr(arr, 'data') and \
351533
type(arr.data).__name__ == 'memoryview'
352534

535+
def _is_xarraylike(self, xarr):
536+
return hasattr(xarr, 'values') and \
537+
hasattr(xarr, 'dims') and \
538+
hasattr(xarr, 'coords') and \
539+
self._is_arraylike(xarr.values)
540+
353541
def _assemble_plugin_macro(self, plugin: str, args=None, ij1_style=True):
354542
"""
355543
Assemble an ImageJ macro string given a plugin to run and optional arguments in a dict

Diff for: test/test_imagej.py

+58-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
ij = imagej.init(ij_dir)
1717

1818

19-
from jnius import autoclass
19+
from jnius import autoclass, cast
2020
import numpy as np
21-
21+
import xarray as xr
2222

2323
class TestImageJ(unittest.TestCase):
2424

@@ -107,6 +107,62 @@ def main(self):
107107
unittest.main()
108108

109109

110+
class TestXarrayConversion(unittest.TestCase):
111+
def testCstyleArrayWithLabeledDimsConverts(self):
112+
xarr = xr.DataArray(np.random.rand(5, 4, 3, 6, 12), dims=['T', 'Z', 'C', 'Y', 'X'],
113+
coords={'X': range(0, 12), 'Y': np.arange(0, 12, 2), 'C': ['R', 'G', 'B'],
114+
'Z': np.arange(10, 50, 10), 'T': np.arange(0, 0.05, 0.01)},
115+
attrs={'Hello': 'Wrld'})
116+
117+
dataset = ij.py.to_java(xarr)
118+
axes = [cast('net.imagej.axis.DefaultLinearAxis', dataset.axis(axnum)) for axnum in range(5)]
119+
labels = [axis.type().getLabel() for axis in axes]
120+
origins = [axis.origin() for axis in axes]
121+
scales = [axis.scale() for axis in axes]
122+
123+
self.assertListEqual(origins, [0, 0, 0, 10, 0])
124+
self.assertListEqual(scales, [1, 2, 1, 10, 0.01])
125+
126+
self.assertListEqual(list(reversed(xarr.dims)), labels)
127+
128+
self.assertEqual(xarr.attrs, ij.py.from_java(dataset.getProperties()))
129+
130+
def testFstyleArrayWiathLabeledDimsConverts(self):
131+
xarr = xr.DataArray(np.ndarray([5, 4, 3, 6, 12], order='F'), dims=['t', 'z', 'c', 'y', 'x'],
132+
coords={'x': range(0, 12), 'y': np.arange(0, 12, 2),
133+
'z': np.arange(10, 50, 10), 't': np.arange(0, 0.05, 0.01)},
134+
attrs={'Hello': 'Wrld'})
135+
136+
dataset = ij.py.to_java(xarr)
137+
axes = [cast('net.imagej.axis.DefaultLinearAxis', dataset.axis(axnum)) for axnum in range(5)]
138+
labels = [axis.type().getLabel() for axis in axes]
139+
origins = [axis.origin() for axis in axes]
140+
scales = [axis.scale() for axis in axes]
141+
142+
self.assertListEqual(origins, [0, 10, 0, 0, 0])
143+
self.assertListEqual(scales, [0.01, 10, 1, 2, 1])
144+
145+
self.assertListEqual([dim.upper() for dim in xarr.dims], labels)
146+
self.assertEqual(xarr.attrs, ij.py.from_java(dataset.getProperties()))
147+
148+
def testDatasetConvertsToXarray(self):
149+
xarr = xr.DataArray(np.random.rand(5, 4, 3, 6, 12), dims=['t', 'z', 'c', 'y', 'x'],
150+
coords={'x': list(range(0, 12)), 'y': list(np.arange(0, 12, 2)), 'c': [0, 1, 2],
151+
'z': list(np.arange(10, 50, 10)), 't': list(np.arange(0, 0.05, 0.01))},
152+
attrs={'Hello': 'Wrld'})
153+
154+
dataset = ij.py.to_java(xarr)
155+
156+
invert_xarr = ij.py.from_java(dataset)
157+
self.assertTrue((xarr.values == invert_xarr.values).all())
158+
159+
self.assertEqual(list(xarr.dims), list(invert_xarr.dims))
160+
for key in xarr.coords:
161+
self.assertTrue((xarr.coords[key] == invert_xarr.coords[key]).all())
162+
self.assertEqual(xarr.attrs, invert_xarr.attrs)
163+
164+
165+
110166
if __name__ == '__main__':
111167
unittest.main()
112168

0 commit comments

Comments
 (0)