|
2 | 2 | Utility functions for querying and manipulating dimensional axis metadata.
|
3 | 3 | """
|
4 | 4 | import logging
|
5 |
| -from typing import List, Tuple |
| 5 | +from typing import List, Tuple, Union |
6 | 6 |
|
7 | 7 | import numpy as np
|
8 | 8 | import scyjava as sj
|
@@ -177,49 +177,53 @@ def prioritize_rai_axes_order(
|
177 | 177 | return permute_order
|
178 | 178 |
|
179 | 179 |
|
180 |
| -def _assign_axes(xarr: xr.DataArray): |
| 180 | +def _assign_axes( |
| 181 | + xarr: xr.DataArray, |
| 182 | +) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]: |
181 | 183 | """
|
182 |
| - Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; |
183 |
| - currently supports EnumeratedAxis |
184 |
| - :param xarr: xarray that holds the units |
185 |
| - :return: A list of ImageJ Axis with the specified origin and scale |
| 184 | + Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both |
| 185 | + DefaultLinearAxis and the newer EnumeratedAxis. |
| 186 | +
|
| 187 | + Note that, in many cases, there are small discrepancies between the coordinates. |
| 188 | + This can either be actually within the data, or it can be from floating point math |
| 189 | + errors. In this case, we delegate to numpy.isclose to tell us whether our |
| 190 | + coordinates are linear or not. If our coordinates are nonlinear, and the |
| 191 | + EnumeratedAxis type is available, we will use it. Otherwise, this function |
| 192 | + returns a DefaultLinearAxis. |
| 193 | +
|
| 194 | + :param xarr: xarray that holds the data. |
| 195 | + :return: A list of ImageJ Axis with the specified origin and scale. |
186 | 196 | """
|
187 |
| - Double = sj.jimport("java.lang.Double") |
188 |
| - |
189 |
| - axes = [""] * len(xarr.dims) |
190 |
| - |
191 |
| - # try to get EnumeratedAxis, if not then default to LinearAxis in the loop |
192 |
| - try: |
193 |
| - EnumeratedAxis = _get_enumerated_axis() |
194 |
| - except (JException, TypeError): |
195 |
| - EnumeratedAxis = None |
196 |
| - |
| 197 | + axes = [""] * xarr.ndim |
197 | 198 | for dim in xarr.dims:
|
198 |
| - axis_str = _convert_dim(dim, direction="java") |
| 199 | + axis_str = _convert_dim(dim, "java") |
199 | 200 | ax_type = jc.Axes.get(axis_str)
|
200 | 201 | ax_num = _get_axis_num(xarr, dim)
|
201 |
| - scale = _get_scale(xarr.coords[dim]) |
| 202 | + coords_arr = xarr.coords[dim] |
202 | 203 |
|
203 |
| - if scale is None: |
| 204 | + # coerce numeric scale |
| 205 | + if not _is_numeric_scale(coords_arr): |
204 | 206 | _logger.warning(
|
205 |
| - f"The {ax_type.label} axis is non-numeric and is translated " |
| 207 | + f"The {ax_type.getLabel()} axis is non-numeric and is translated " |
206 | 208 | "to a linear index."
|
207 | 209 | )
|
208 |
| - doub_coords = [ |
209 |
| - Double(np.double(x)) for x in np.arange(len(xarr.coords[dim])) |
210 |
| - ] |
| 210 | + coords_arr = [np.double(x) for x in np.arange(len(xarr.coords[dim]))] |
211 | 211 | else:
|
212 |
| - doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]] |
213 |
| - |
214 |
| - # EnumeratedAxis is a new axis made for xarray, so is only present in |
215 |
| - # ImageJ versions that are released later than March 2020. |
216 |
| - # This actually returns a LinearAxis if using an earlier version. |
217 |
| - if EnumeratedAxis is not None: |
218 |
| - java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords)) |
| 212 | + coords_arr = coords_arr.to_numpy().astype(np.double) |
| 213 | + |
| 214 | + # check scale linearity |
| 215 | + diffs = np.diff(coords_arr) |
| 216 | + linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0])) |
| 217 | + |
| 218 | + if not linear: |
| 219 | + try: |
| 220 | + j_coords = [jc.Double(x) for x in coords_arr] |
| 221 | + axes[ax_num] = jc.EnumeratedAxis(ax_type, sj.to_java(j_coords)) |
| 222 | + except (JException, TypeError): |
| 223 | + # if EnumeratedAxis not available - use DefaultLinearAxis |
| 224 | + axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type) |
219 | 225 | else:
|
220 |
| - java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) |
221 |
| - |
222 |
| - axes[ax_num] = java_axis |
| 226 | + axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type) |
223 | 227 |
|
224 | 228 | return axes
|
225 | 229 |
|
@@ -274,48 +278,26 @@ def _get_axes_coords(
|
274 | 278 | return coords
|
275 | 279 |
|
276 | 280 |
|
277 |
| -def _get_scale(axis): |
| 281 | +def _get_default_linear_axis(coords_arr: np.ndarray, ax_type: "jc.AxisType"): |
278 | 282 | """
|
279 |
| - Get the scale of an axis, assuming it is linear and so the scale is simply |
280 |
| - second - first coordinate. |
| 283 | + Create a new DefaultLinearAxis with the given coordinate array and axis type. |
281 | 284 |
|
282 |
| - :param axis: A 1D list like entry accessible with indexing, which contains the |
283 |
| - axis coordinates |
284 |
| - :return: The scale for this axis or None if it is a non-numeric scale. |
| 285 | + :param coords_arr: A 1D NumPy array. |
| 286 | + :return: An instance of net.imagej.axis.DefaultLinearAxis. |
285 | 287 | """
|
286 |
| - try: |
287 |
| - # HACK: This axis length check is a work around for singleton dimensions. |
288 |
| - # You can't calculate the slope of a singleton dimension. |
289 |
| - # This section will be removed when axis-scale-logic is merged. |
290 |
| - if len(axis) <= 1: |
291 |
| - return 1 |
292 |
| - else: |
293 |
| - return axis.values[1] - axis.values[0] |
294 |
| - except TypeError: |
295 |
| - return None |
296 |
| - |
| 288 | + scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1 |
| 289 | + origin = coords_arr[0] if len(coords_arr) > 0 else 0 |
| 290 | + return jc.DefaultLinearAxis(ax_type, jc.Double(scale), jc.Double(origin)) |
297 | 291 |
|
298 |
| -def _get_enumerated_axis(): |
299 |
| - """Get EnumeratedAxis. |
300 | 292 |
|
301 |
| - EnumeratedAxis is only in releases later than March 2020. If using |
302 |
| - an older version of ImageJ without EnumeratedAxis, use |
303 |
| - _get_linear_axis() instead. |
| 293 | +def _is_numeric_scale(coords_array: np.ndarray) -> bool: |
304 | 294 | """
|
305 |
| - return sj.jimport("net.imagej.axis.EnumeratedAxis") |
306 |
| - |
307 |
| - |
308 |
| -def _get_linear_axis(axis_type: "jc.AxisType", values): |
309 |
| - """Get linear axis. |
| 295 | + Checks if the coordinates array of the given axis is numeric. |
310 | 296 |
|
311 |
| - This is used if no EnumeratedAxis is found. If EnumeratedAxis |
312 |
| - is available, use _get_enumerated_axis() instead. |
| 297 | + :param coords_array: A 1D NumPy array. |
| 298 | + :return: bool |
313 | 299 | """
|
314 |
| - DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis") |
315 |
| - origin = values[0] |
316 |
| - scale = values[1] - values[0] |
317 |
| - axis = DefaultLinearAxis(axis_type, scale, origin) |
318 |
| - return axis |
| 300 | + return np.issubdtype(coords_array.dtype, np.number) |
319 | 301 |
|
320 | 302 |
|
321 | 303 | def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":
|
|
0 commit comments