Skip to content

Commit 717229d

Browse files
committed
Refactor code
@elevans and I talked about this in person. I don't see much reason to remove EnumeratedAxis support, since it isn't hurting anything. Maybe we can just exert more pressure towards using DefaultLinearAxis, and we can use NumPy to check dimension linearity!
1 parent 2930c20 commit 717229d

File tree

2 files changed

+33
-148
lines changed

2 files changed

+33
-148
lines changed

src/imagej/_java.py

-48
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ def MetadataWrapper(self):
5454
def LabelingIOService(self):
5555
return "io.scif.labeling.LabelingIOService"
5656

57-
@JavaClasses.java_import
58-
def ChapmanRichardsAxis(self):
59-
return "net.imagej.axis.ChapmanRichardsAxis"
60-
6157
@JavaClasses.java_import
6258
def DefaultLinearAxis(self):
6359
return "net.imagej.axis.DefaultLinearAxis"
@@ -66,50 +62,6 @@ def DefaultLinearAxis(self):
6662
def EnumeratedAxis(self):
6763
return "net.imagej.axis.EnumeratedAxis"
6864

69-
@JavaClasses.java_import
70-
def ExponentialAxis(self):
71-
return "net.imagej.axis.ExponentialAxis"
72-
73-
@JavaClasses.java_import
74-
def ExponentialRecoveryAxis(self):
75-
return "net.imagej.axis.ExponentialRecoveryAxis"
76-
77-
@JavaClasses.java_import
78-
def GammaVariateAxis(self):
79-
return "net.imagej.axis.GammaVariateAxis"
80-
81-
@JavaClasses.java_import
82-
def GaussianAxis(self):
83-
return "net.imagej.axis.GaussianAxis"
84-
85-
@JavaClasses.java_import
86-
def IdentityAxis(self):
87-
return "net.imagej.axis.IdentityAxis"
88-
89-
@JavaClasses.java_import
90-
def InverseRodbardAxis(self):
91-
return "net.imagej.axis.InverseRodbardAxis"
92-
93-
@JavaClasses.java_import
94-
def LogLinearAxis(self):
95-
return "net.imagej.axis.LogLinearAxis"
96-
97-
@JavaClasses.java_import
98-
def PolynomialAxis(self):
99-
return "net.imagej.axis.PolynomialAxis"
100-
101-
@JavaClasses.java_import
102-
def PowerAxis(self):
103-
return "net.imagej.axis.PowerAxis"
104-
105-
@JavaClasses.java_import
106-
def RodbardAxis(self):
107-
return "net.imagej.axis.RodbardAxis"
108-
109-
@JavaClasses.java_import
110-
def VariableAxis(self):
111-
return "net.iamgej.axis.VariableAxis"
112-
11365
@JavaClasses.java_import
11466
def Dataset(self):
11567
return "net.imagej.Dataset"

src/imagej/dims.py

+33-100
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,14 @@ def _assign_axes(
183183
"""
184184
Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both
185185
DefaultLinearAxis and the newer EnumeratedAxis.
186+
187+
Note that, in many cases, there are small discrepancies between the coordinates.
188+
This can either be actually within the data, or it can be from floating point math
189+
errors. In this case, we delegate to numpy.isclose to tell us whether our
190+
coordinates are linear or not. If our coordinates are nonlinear, and the
191+
EnumeratedAxis type is available, we will use it. Otherwise, this function
192+
returns a DefaultLinearAxis.
193+
186194
:param xarr: xarray that holds the data.
187195
:return: A list of ImageJ Axis with the specified origin and scale.
188196
"""
@@ -191,41 +199,37 @@ def _assign_axes(
191199
axis_str = _convert_dim(dim, "java")
192200
ax_type = jc.Axes.get(axis_str)
193201
ax_num = _get_axis_num(xarr, dim)
194-
coords_arr = xarr.coords[dim].to_numpy()
202+
coords_arr = xarr.coords[dim].to_numpy().astype(np.double)
195203

196-
# check if coords/scale is numeric
197-
if _is_numeric_scale(coords_arr):
198-
doub_coords = [jc.Double(np.double(x)) for x in xarr.coords[dim]]
199-
else:
204+
# coerce numeric scale
205+
if not _is_numeric_scale(coords_arr):
200206
_logger.warning(
201207
f"The {ax_type.label} axis is non-numeric and is translated "
202208
"to a linear index."
203209
)
204-
doub_coords = [
205-
jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim]))
206-
]
207-
208-
# assign calibrated axis type -- checks for imagej metadata
209-
if "imagej" in xarr.attrs.keys():
210-
ij_dim = _convert_dim(dim, "java")
211-
if ij_dim + "_cal_axis_type" in xarr.attrs["imagej"].keys():
212-
scale_type = xarr.attrs["imagej"][ij_dim + "_cal_axis_type"]
213-
if scale_type == "linear":
214-
jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
215-
if scale_type == "enumerated":
216-
try:
217-
EnumeratedAxis = _get_enumerated_axis()
218-
except (JException, TypeError):
219-
EnumeratedAxis = None
220-
if EnumeratedAxis is not None:
221-
jaxis = EnumeratedAxis(ax_type, sj.to_java(doub_coords))
222-
else:
223-
jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
210+
coords_arr = [np.double(x) for x in np.arrange(len(xarr.coords[dim]))]
211+
212+
# check scale linearity
213+
diffs = np.diff(coords_arr)
214+
linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0]))
215+
216+
# For non-linear scales, use EnumeratedAxis
217+
try:
218+
EnumeratedAxis = sj.jimport("net.imagej.axis.EnumeratedAxis")
219+
except (JException, TypeError):
220+
EnumeratedAxis = None
221+
# If we can use EnumeratedAxis for a nonlinear scale, then use it
222+
if not linear and EnumeratedAxis:
223+
j_coords = [jc.Double(x) for x in coords_arr]
224+
axes[ax_num] = EnumeratedAxis(ax_type, sj.to_java(j_coords))
225+
# Otherwise, use DefaultLinearAxis
224226
else:
225-
# default to DefaultLinearAxis always if no `scale_type` key in attr
226-
jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
227-
228-
axes[ax_num] = jaxis
227+
DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis")
228+
scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1
229+
origin = coords_arr[0] if len(coords_arr) > 0 else 0
230+
axes[ax_num] = DefaultLinearAxis(
231+
ax_type, jc.Double(scale), jc.Double(origin)
232+
)
229233

230234
return axes
231235

@@ -280,27 +284,6 @@ def _get_axes_coords(
280284
return coords
281285

282286

283-
def _get_scale(axis):
284-
"""
285-
Get the scale of an axis, assuming it is linear and so the scale is simply
286-
second - first coordinate.
287-
288-
:param axis: A 1D list like entry accessible with indexing, which contains the
289-
axis coordinates
290-
:return: The scale for this axis or None if it is a non-numeric scale.
291-
"""
292-
try:
293-
# HACK: This axis length check is a work around for singleton dimensions.
294-
# You can't calculate the slope of a singleton dimension.
295-
# This section will be removed when axis-scale-logic is merged.
296-
if len(axis) <= 1:
297-
return 1
298-
else:
299-
return axis.values[1] - axis.values[0]
300-
except TypeError:
301-
return None
302-
303-
304287
def _is_numeric_scale(coords_array: np.ndarray) -> bool:
305288
"""
306289
Checks if the coordinates array of the given axis is numeric.
@@ -311,29 +294,6 @@ def _is_numeric_scale(coords_array: np.ndarray) -> bool:
311294
return np.issubdtype(coords_array.dtype, np.number)
312295

313296

314-
def _get_enumerated_axis():
315-
"""Get EnumeratedAxis.
316-
317-
EnumeratedAxis is only in releases later than March 2020. If using
318-
an older version of ImageJ without EnumeratedAxis, use
319-
_get_linear_axis() instead.
320-
"""
321-
return sj.jimport("net.imagej.axis.EnumeratedAxis")
322-
323-
324-
def _get_linear_axis(axis_type: "jc.AxisType", values):
325-
"""Get linear axis.
326-
327-
This is used if no EnumeratedAxis is found. If EnumeratedAxis
328-
is available, use _get_enumerated_axis() instead.
329-
"""
330-
DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis")
331-
origin = values[0]
332-
scale = values[1] - values[0]
333-
axis = DefaultLinearAxis(axis_type, scale, origin)
334-
return axis
335-
336-
337297
def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":
338298
"""Get an ImgPlus from a Dataset.
339299
@@ -483,30 +443,3 @@ def _to_ijdim(key: str) -> str:
483443
return ijdims[key]
484444
else:
485445
return key
486-
487-
488-
def _cal_axis_type_to_str(key) -> str:
489-
"""
490-
Convert a CalibratedAxis type (e.g. net.imagej.axis.DefaultLinearAxis) to
491-
a string.
492-
"""
493-
cal_axis_types = {
494-
jc.ChapmanRichardsAxis: "ChapmanRichardsAxis",
495-
jc.DefaultLinearAxis: "DefaultLinearAxis",
496-
jc.EnumeratedAxis: "EnumeratedAxis",
497-
jc.ExponentialAxis: "ExponentialAxis",
498-
jc.ExponentialRecoveryAxis: "ExponentialRecoveryAxis",
499-
jc.GammaVariateAxis: "GammaVariateAxis",
500-
jc.GaussianAxis: "GaussianAxis",
501-
jc.IdentityAxis: "IdentityAxis",
502-
jc.InverseRodbardAxis: "InverseRodbardAxis",
503-
jc.LogLinearAxis: "LogLinearAxis",
504-
jc.PolynomialAxis: "PolynomialAxis",
505-
jc.PowerAxis: "PowerAxis",
506-
jc.RodbardAxis: "RodbardAxis",
507-
}
508-
509-
if key.__class__ in cal_axis_types:
510-
return cal_axis_types[key.__class__]
511-
else:
512-
return "unknown"

0 commit comments

Comments
 (0)