Skip to content

Commit 9f2feb7

Browse files
authored
Merge pull request #265 from imagej/linear-axis-logic
Simplify linear axis assignment logic
2 parents 6968b63 + 68a77ee commit 9f2feb7

File tree

3 files changed

+160
-67
lines changed

3 files changed

+160
-67
lines changed

src/imagej/_java.py

+12
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class MyJavaClasses(JavaClasses):
3030
significantly easier and more readable.
3131
"""
3232

33+
@JavaClasses.java_import
34+
def Double(self):
35+
return "java.lang.Double"
36+
3337
@JavaClasses.java_import
3438
def Throwable(self):
3539
return "java.lang.Throwable"
@@ -50,6 +54,14 @@ def MetadataWrapper(self):
5054
def LabelingIOService(self):
5155
return "io.scif.labeling.LabelingIOService"
5256

57+
@JavaClasses.java_import
58+
def DefaultLinearAxis(self):
59+
return "net.imagej.axis.DefaultLinearAxis"
60+
61+
@JavaClasses.java_import
62+
def EnumeratedAxis(self):
63+
return "net.imagej.axis.EnumeratedAxis"
64+
5365
@JavaClasses.java_import
5466
def Dataset(self):
5567
return "net.imagej.Dataset"

src/imagej/dims.py

+49-67
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Utility functions for querying and manipulating dimensional axis metadata.
33
"""
44
import logging
5-
from typing import List, Tuple
5+
from typing import List, Tuple, Union
66

77
import numpy as np
88
import scyjava as sj
@@ -177,49 +177,53 @@ def prioritize_rai_axes_order(
177177
return permute_order
178178

179179

180-
def _assign_axes(xarr: xr.DataArray):
180+
def _assign_axes(
181+
xarr: xr.DataArray,
182+
) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]:
181183
"""
182-
Obtain xarray axes names, origin, and scale and convert into ImageJ Axis;
183-
currently supports EnumeratedAxis
184-
:param xarr: xarray that holds the units
185-
:return: A list of ImageJ Axis with the specified origin and scale
184+
Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both
185+
DefaultLinearAxis and the newer EnumeratedAxis.
186+
187+
Note that, in many cases, there are small discrepancies between the coordinates.
188+
This can either be actually within the data, or it can be from floating point math
189+
errors. In this case, we delegate to numpy.isclose to tell us whether our
190+
coordinates are linear or not. If our coordinates are nonlinear, and the
191+
EnumeratedAxis type is available, we will use it. Otherwise, this function
192+
returns a DefaultLinearAxis.
193+
194+
:param xarr: xarray that holds the data.
195+
:return: A list of ImageJ Axis with the specified origin and scale.
186196
"""
187-
Double = sj.jimport("java.lang.Double")
188-
189-
axes = [""] * len(xarr.dims)
190-
191-
# try to get EnumeratedAxis, if not then default to LinearAxis in the loop
192-
try:
193-
EnumeratedAxis = _get_enumerated_axis()
194-
except (JException, TypeError):
195-
EnumeratedAxis = None
196-
197+
axes = [""] * xarr.ndim
197198
for dim in xarr.dims:
198-
axis_str = _convert_dim(dim, direction="java")
199+
axis_str = _convert_dim(dim, "java")
199200
ax_type = jc.Axes.get(axis_str)
200201
ax_num = _get_axis_num(xarr, dim)
201-
scale = _get_scale(xarr.coords[dim])
202+
coords_arr = xarr.coords[dim]
202203

203-
if scale is None:
204+
# coerce numeric scale
205+
if not _is_numeric_scale(coords_arr):
204206
_logger.warning(
205-
f"The {ax_type.label} axis is non-numeric and is translated "
207+
f"The {ax_type.getLabel()} axis is non-numeric and is translated "
206208
"to a linear index."
207209
)
208-
doub_coords = [
209-
Double(np.double(x)) for x in np.arange(len(xarr.coords[dim]))
210-
]
210+
coords_arr = [np.double(x) for x in np.arange(len(xarr.coords[dim]))]
211211
else:
212-
doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]]
213-
214-
# EnumeratedAxis is a new axis made for xarray, so is only present in
215-
# ImageJ versions that are released later than March 2020.
216-
# This actually returns a LinearAxis if using an earlier version.
217-
if EnumeratedAxis is not None:
218-
java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords))
212+
coords_arr = coords_arr.to_numpy().astype(np.double)
213+
214+
# check scale linearity
215+
diffs = np.diff(coords_arr)
216+
linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0]))
217+
218+
if not linear:
219+
try:
220+
j_coords = [jc.Double(x) for x in coords_arr]
221+
axes[ax_num] = jc.EnumeratedAxis(ax_type, sj.to_java(j_coords))
222+
except (JException, TypeError):
223+
# if EnumeratedAxis not available - use DefaultLinearAxis
224+
axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type)
219225
else:
220-
java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
221-
222-
axes[ax_num] = java_axis
226+
axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type)
223227

224228
return axes
225229

@@ -274,48 +278,26 @@ def _get_axes_coords(
274278
return coords
275279

276280

277-
def _get_scale(axis):
281+
def _get_default_linear_axis(coords_arr: np.ndarray, ax_type: "jc.AxisType"):
278282
"""
279-
Get the scale of an axis, assuming it is linear and so the scale is simply
280-
second - first coordinate.
283+
Create a new DefaultLinearAxis with the given coordinate array and axis type.
281284
282-
:param axis: A 1D list like entry accessible with indexing, which contains the
283-
axis coordinates
284-
:return: The scale for this axis or None if it is a non-numeric scale.
285+
:param coords_arr: A 1D NumPy array.
286+
:return: An instance of net.imagej.axis.DefaultLinearAxis.
285287
"""
286-
try:
287-
# HACK: This axis length check is a work around for singleton dimensions.
288-
# You can't calculate the slope of a singleton dimension.
289-
# This section will be removed when axis-scale-logic is merged.
290-
if len(axis) <= 1:
291-
return 1
292-
else:
293-
return axis.values[1] - axis.values[0]
294-
except TypeError:
295-
return None
296-
288+
scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1
289+
origin = coords_arr[0] if len(coords_arr) > 0 else 0
290+
return jc.DefaultLinearAxis(ax_type, jc.Double(scale), jc.Double(origin))
297291

298-
def _get_enumerated_axis():
299-
"""Get EnumeratedAxis.
300292

301-
EnumeratedAxis is only in releases later than March 2020. If using
302-
an older version of ImageJ without EnumeratedAxis, use
303-
_get_linear_axis() instead.
293+
def _is_numeric_scale(coords_array: np.ndarray) -> bool:
304294
"""
305-
return sj.jimport("net.imagej.axis.EnumeratedAxis")
306-
307-
308-
def _get_linear_axis(axis_type: "jc.AxisType", values):
309-
"""Get linear axis.
295+
Checks if the coordinates array of the given axis is numeric.
310296
311-
This is used if no EnumeratedAxis is found. If EnumeratedAxis
312-
is available, use _get_enumerated_axis() instead.
297+
:param coords_array: A 1D NumPy array.
298+
:return: bool
313299
"""
314-
DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis")
315-
origin = values[0]
316-
scale = values[1] - values[0]
317-
axis = DefaultLinearAxis(axis_type, scale, origin)
318-
return axis
300+
return np.issubdtype(coords_array.dtype, np.number)
319301

320302

321303
def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":

tests/test_image_conversion.py

+99
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
import string
23

34
import numpy as np
45
import pytest
@@ -7,6 +8,7 @@
78

89
import imagej.dims as dims
910
import imagej.images as images
11+
from imagej._java import jc
1012

1113
# -- Image helpers --
1214

@@ -94,6 +96,75 @@ def get_xarr(option="C"):
9496
return xarr
9597

9698

99+
def get_non_linear_coord_xarr(option="C"):
100+
name: str = "non_linear_coord_data_array"
101+
linear_coord_arr = np.arange(5)
102+
# generate a 1D log scale array
103+
non_linear_coord_arr = np.logspace(0, np.log10(100), num=30)
104+
if option == "C":
105+
xarr = xr.DataArray(
106+
np.random.rand(30, 30, 5),
107+
dims=["row", "col", "ch"],
108+
coords={
109+
"row": non_linear_coord_arr,
110+
"col": non_linear_coord_arr,
111+
"ch": linear_coord_arr,
112+
},
113+
attrs={"Hello": "World"},
114+
name=name,
115+
)
116+
elif option == "F":
117+
xarr = xr.DataArray(
118+
np.ndarray([30, 30, 5], order="F"),
119+
dims=["row", "col", "ch"],
120+
coords={
121+
"row": non_linear_coord_arr,
122+
"col": non_linear_coord_arr,
123+
"ch": linear_coord_arr,
124+
},
125+
attrs={"Hello": "World"},
126+
name=name,
127+
)
128+
else:
129+
xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name)
130+
131+
return xarr
132+
133+
134+
def get_non_numeric_coord_xarr(option="C"):
135+
name: str = "non_numeric_coord_data_array"
136+
non_numeric_coord_list = [random.choice(string.ascii_letters) for _ in range(30)]
137+
linear_coord_arr = np.arange(5)
138+
if option == "C":
139+
xarr = xr.DataArray(
140+
np.random.rand(30, 30, 5),
141+
dims=["row", "col", "ch"],
142+
coords={
143+
"row": non_numeric_coord_list,
144+
"col": non_numeric_coord_list,
145+
"ch": linear_coord_arr,
146+
},
147+
attrs={"Hello": "World"},
148+
name=name,
149+
)
150+
elif option == "F":
151+
xarr = xr.DataArray(
152+
np.ndarray([30, 30, 5], order="F"),
153+
dims=["row", "col", "ch"],
154+
coords={
155+
"row": non_numeric_coord_list,
156+
"col": non_numeric_coord_list,
157+
"ch": linear_coord_arr,
158+
},
159+
attrs={"Hello": "World"},
160+
name=name,
161+
)
162+
else:
163+
xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name)
164+
165+
return xarr
166+
167+
97168
# -- Helpers --
98169

99170

@@ -359,6 +430,34 @@ def test_no_coords_or_dims_in_xarr(ij_fixture):
359430
assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr)
360431

361432

433+
def test_linear_coord_on_xarr_conversion(ij_fixture):
434+
xarr = get_xarr()
435+
dataset = ij_fixture.py.to_java(xarr)
436+
axes = dataset.dim_axes
437+
# all axes should be DefaultLinearAxis
438+
for ax in axes:
439+
assert isinstance(ax, jc.DefaultLinearAxis)
440+
441+
442+
def test_non_linear_coord_on_xarr_conversion(ij_fixture):
443+
xarr = get_non_linear_coord_xarr()
444+
dataset = ij_fixture.py.to_java(xarr)
445+
axes = dataset.dim_axes
446+
# axes [0, 1] should be EnumeratedAxis with axis 2 as DefaultLinearAxis
447+
for i in range(2):
448+
assert isinstance(axes[i], jc.EnumeratedAxis)
449+
assert isinstance(axes[-1], jc.DefaultLinearAxis)
450+
451+
452+
def test_non_numeric_coord_on_xarr_conversion(ij_fixture):
453+
xarr = get_non_numeric_coord_xarr()
454+
dataset = ij_fixture.py.to_java(xarr)
455+
axes = dataset.dim_axes
456+
# all axes should be DefaultLinearAxis
457+
for ax in axes:
458+
assert isinstance(ax, jc.DefaultLinearAxis)
459+
460+
362461
dataset_conversion_parameters = [
363462
(
364463
get_img,

0 commit comments

Comments
 (0)