13
13
import jnius_config
14
14
from pathlib import Path
15
15
import numpy
16
+ import xarray as xr
17
+ import warnings
16
18
17
19
_logger = logging .getLogger (__name__ )
18
20
@@ -124,6 +126,8 @@ def init(ij_dir_or_version_or_endpoint=None, headless=True, new_instance=False):
124
126
# Must import imglyb (not scyjava) to spin up the JVM now.
125
127
import imglyb
126
128
from jnius import autoclass
129
+ from jnius import cast
130
+ import scyjava
127
131
128
132
# Initialize ImageJ.
129
133
ImageJ = autoclass ('net.imagej.ImageJ' )
@@ -134,7 +138,11 @@ def init(ij_dir_or_version_or_endpoint=None, headless=True, new_instance=False):
134
138
from scyjava import jclass , isjava , to_java , to_python
135
139
136
140
Dataset = autoclass ('net.imagej.Dataset' )
141
+ ImgPlus = autoclass ('net.imagej.ImgPlus' )
142
+ Img = autoclass ('net.imglib2.img.Img' )
137
143
RandomAccessibleInterval = autoclass ('net.imglib2.RandomAccessibleInterval' )
144
+ Axes = autoclass ('net.imagej.axis.Axes' )
145
+ DefaultLinearAxis = autoclass ('net.imagej.axis.DefaultLinearAxis' )
138
146
139
147
class ImageJPython :
140
148
def __init__ (self , ij ):
@@ -286,19 +294,159 @@ def run_script(self, language, script, args=None):
286
294
287
295
def to_java (self , data ):
288
296
"""
289
- Converts the data into a java equivalent. For numpy arrays, the java image points to the python array
297
+ Converts the data into a java equivalent. For numpy arrays, the java image points to the python array.
298
+
299
+ In addition to the scyjava types, we allow ndarray-like and xarray-like variables
290
300
"""
291
301
if self ._is_memoryarraylike (data ):
292
302
return imglyb .to_imglib (data )
303
+ if self ._is_xarraylike (data ):
304
+ return self .to_dataset (data )
293
305
return to_java (data )
294
306
295
307
def to_dataset (self , data ):
308
+ """Converts the data into an ImageJ dataset"""
309
+ if self ._is_xarraylike (data ):
310
+ return self ._xarray_to_dataset (data )
311
+ if self ._is_arraylike (data ):
312
+ return self ._numpy_to_dataset (data )
313
+ if scyjava .isjava (data ):
314
+ return self ._java_to_dataset (data )
315
+
316
+ raise TypeError (f'Type not supported: { type (data )} ' )
317
+
318
+ def _numpy_to_dataset (self , data ):
319
+ rai = imglyb .to_imglib (data )
320
+ return self ._java_to_dataset (rai )
321
+
322
+ def _xarray_to_dataset (self , xarr ):
323
+ """
324
+ Converts a xarray dataarray with specified dim order to an image
325
+ :param xarr: Pass an xarray dataarray and turn into a dataset.
326
+ :return: The dataset
327
+ """
328
+ dataset = self ._numpy_to_dataset (xarr .values )
329
+ axes = self ._assign_axes (xarr )
330
+ dataset .setAxes (axes )
331
+
332
+ # Currently, we have no handling for nonlinear axes, but I thought it should warn instead of fail.
333
+ if not self ._axis_is_linear (xarr .coords ):
334
+ warnings .warn ("Not all axes are linear. The nonlinear axes are not mapped correctly." )
335
+
336
+ self ._assign_dataset_metadata (dataset , xarr .attrs )
337
+
338
+ return dataset
339
+
340
+ def _assign_axes (self , xarr ):
341
+ """
342
+ Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; currently supports DefaultLinearAxis.
343
+ :param xarr: xarray that holds the units
344
+ :return: A list of ImageJ Axis with the specified origin and scale
345
+ """
346
+ axes = ['' ]* len (xarr .dims )
347
+
348
+ for axis in xarr .dims :
349
+ origin = self ._get_origin (xarr .coords [axis ])
350
+ scale = self ._get_scale (xarr .coords [axis ])
351
+
352
+ axisStr = self ._pydim_to_ijdim (axis )
353
+
354
+ ax_type = Axes .get (axisStr )
355
+ ax_num = self ._get_axis_num (xarr , axis )
356
+ if scale is None :
357
+ java_axis = DefaultLinearAxis (ax_type )
358
+ else :
359
+ java_axis = DefaultLinearAxis (ax_type , numpy .double (scale ), numpy .double (origin ))
360
+
361
+ axes [ax_num ] = java_axis
362
+
363
+ return axes
364
+
365
+ def _pydim_to_ijdim (self , axis ):
366
+ """Convert between the lowercase Python convention (x, y, z, c, t) to IJ (X, Y, Z, C, T)"""
367
+ if str (axis ) in ['x' , 'y' , 'z' , 'c' , 't' ]:
368
+ return str (axis ).upper ()
369
+ return str (axis )
370
+
371
+ def _ijdim_to_pydim (self , axis ):
372
+ """Convert the IJ uppercase dimension convention (X, Y, Z< C, T) to lowercase python (x, y, z, c, t) """
373
+ if str (axis ) in ['X' , 'Y' , 'Z' , 'C' , 'T' ]:
374
+ return str (axis ).lower ()
375
+ return str (axis )
376
+
377
+ def _get_axis_num (self , xarr , axis ):
378
+ """
379
+ Get the xarray -> java axis number due to inverted axis order for C style numpy arrays (default)
380
+ :param xarr: Xarray to convert
381
+ :param axis: Axis number to convert
382
+ :return: Axis idx in java
383
+ """
384
+ py_axnum = xarr .get_axis_num (axis )
385
+ if numpy .isfortran (xarr .values ):
386
+ return py_axnum
387
+
388
+ return len (xarr .dims ) - py_axnum - 1
389
+
390
+
391
+ def _assign_dataset_metadata (self , dataset , attrs ):
392
+ """
393
+ :param dataset: ImageJ Java dataset
394
+ :param attrs: Dictionary containing metadata
395
+ """
396
+ dataset .getProperties ().putAll (self .to_java (attrs ))
397
+
398
+ def _axis_is_linear (self , coords ):
399
+ """
400
+ Check if each axis has linear steps between grid points. Skip over axes with non-numeric entries
401
+ :param coords: Xarray coords variable, which is a dict with axis: [axis values]
402
+ :return: Whether all axes are linear, or not.
403
+ """
404
+ linear = True
405
+ for coord , values in coords .items ():
406
+ try :
407
+ diff = numpy .diff (coords )
408
+ if len (numpy .unique (diff )) > 1 :
409
+ warnings .warn (f'Axis { coord } is not linear' )
410
+ linear = False
411
+ except TypeError :
412
+ continue
413
+ return linear
414
+
415
+ def _get_origin (self , axis ):
416
+ """
417
+ Get the coordinate origin of an axis, assuming it is the first entry.
418
+ :param axis: A 1D list like entry accessible with indexing, which contains the axis coordinates
419
+ :return: The origin for this axis.
420
+ """
421
+ return axis .values [0 ]
422
+
423
+ def _get_scale (self , axis ):
424
+ """
425
+ Get the scale of an axis, assuming it is linear and so the scale is simply second - first coordinate.
426
+ :param axis: A 1D list like entry accessible with indexing, which contains the axis coordinates
427
+ :return: The scale for this axis or None if it is a non-numeric scale.
428
+ """
429
+ try :
430
+ return axis .values [1 ] - axis .values [0 ]
431
+ except TypeError :
432
+ return None
433
+
434
+ def _java_to_dataset (self , data ):
296
435
"""
297
436
Converts the data into a ImageJ Dataset
298
437
"""
299
438
try :
300
439
if self ._ij .convert ().supports (data , Dataset ):
301
440
return self ._ij .convert ().convert (data , Dataset )
441
+ if self ._ij .convert ().supports (data , ImgPlus ):
442
+ imgPlus = self ._ij .convert ().convert (data , ImgPlus )
443
+ return self ._ij .dataset ().create (imgPlus )
444
+ if self ._ij .convert ().supports (data , Img ):
445
+ img = self ._ij .convert ().convert (data , Img )
446
+ return self ._ij .dataset ().create (ImgPlus (img ))
447
+ if self ._ij .convert ().supports (data , RandomAccessibleInterval ):
448
+ rai = self ._ij .convert ().convert (data , RandomAccessibleInterval )
449
+ return self ._ij .dataset ().create (rai )
302
450
except Exception as exc :
303
451
_dump_exception (exc )
304
452
raise exc
@@ -308,11 +456,14 @@ def from_java(self, data):
308
456
"""
309
457
Converts the data into a python equivalent
310
458
"""
459
+ # todo: convert a datset to xarray
460
+
311
461
if not isjava (data ): return data
312
462
try :
313
463
if self ._ij .convert ().supports (data , Dataset ):
314
464
# HACK: Converter exists for ImagePlus -> Dataset, but not ImagePlus -> RAI.
315
465
data = self ._ij .convert ().convert (data , Dataset )
466
+ return self ._dataset_to_xarray (data )
316
467
if (self ._ij .convert ().supports (data , RandomAccessibleInterval )):
317
468
rai = self ._ij .convert ().convert (data , RandomAccessibleInterval )
318
469
return self .rai_to_numpy (rai )
@@ -321,6 +472,37 @@ def from_java(self, data):
321
472
raise exc
322
473
return to_python (data )
323
474
475
+ def _dataset_to_xarray (self , dataset ):
476
+ """
477
+ Converts an ImageJ dataset into an xarray
478
+ :param dataset: ImageJ dataset
479
+ :return: xarray with reversed (C-style) dims and coords as labeled by the dataset
480
+ """
481
+ attrs = self ._ij .py .from_java (dataset .getProperties ())
482
+ axes = [(cast ('net.imagej.axis.DefaultLinearAxis' , dataset .axis (idx )))
483
+ for idx in range (dataset .numDimensions ())]
484
+
485
+ dims = [self ._ijdim_to_pydim (axes [idx ].type ().getLabel ()) for idx in range (len (axes ))]
486
+ values = self .rai_to_numpy (dataset )
487
+ coords = self ._get_axes_coords (axes , dims , numpy .shape (numpy .transpose (values )))
488
+
489
+ xarr = xr .DataArray (values , dims = list (reversed (dims )), coords = coords , attrs = attrs )
490
+ return xarr
491
+
492
+ def _get_axes_coords (self , axes , dims , shape ):
493
+ """
494
+ Get xarray style coordinate list dictionary from a dataset
495
+ :param axes: List of ImageJ axes
496
+ :param dims: List of axes labels for each dataset axis
497
+ :param shape: F-style, or reversed C-style, shape of axes numpy array.
498
+ :return: Dictionary of coordinates for each axis.
499
+ """
500
+ coords = {dims [idx ]: numpy .arange (axes [idx ].origin (), shape [idx ]* axes [idx ].scale () + axes [idx ].origin (),
501
+ axes [idx ].scale ())
502
+ for idx in range (len (dims ))}
503
+ return coords
504
+
505
+
324
506
def show (self , image , cmap = None ):
325
507
"""
326
508
Display a java or python 2D image.
@@ -350,6 +532,12 @@ def _is_memoryarraylike(self, arr):
350
532
hasattr (arr , 'data' ) and \
351
533
type (arr .data ).__name__ == 'memoryview'
352
534
535
+ def _is_xarraylike (self , xarr ):
536
+ return hasattr (xarr , 'values' ) and \
537
+ hasattr (xarr , 'dims' ) and \
538
+ hasattr (xarr , 'coords' ) and \
539
+ self ._is_arraylike (xarr .values )
540
+
353
541
def _assemble_plugin_macro (self , plugin : str , args = None , ij1_style = True ):
354
542
"""
355
543
Assemble an ImageJ macro string given a plugin to run and optional arguments in a dict
0 commit comments