Skip to content

Commit 56ca272

Browse files
committed
Refactor metadata module to be more pythonic
The axis submodule has also been streamlined to drop the CalibratedAxis dict and instead uses a list of Strings. This avoids Java import errors if a user imports the axis submodule before initializing ImageJ.
1 parent 61a814c commit 56ca272

File tree

5 files changed

+87
-118
lines changed

5 files changed

+87
-118
lines changed

src/imagej/convert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def java_to_xarray(ij: "jc.ImageJ", jobj) -> xr.DataArray:
233233
xr_dims = list(permuted_rai.dims)
234234
xr_attrs = sj.to_python(permuted_rai.getProperties())
235235
xr_attrs = {sj.to_python(k): sj.to_python(v) for k, v in xr_attrs.items()}
236-
xr_attrs["imagej"] = metadata.ImageMetadata.create_imagej_metadata(xr_axes, xr_dims)
236+
xr_attrs["imagej"] = metadata.create_imagej_metadata(xr_axes, xr_dims)
237237
# reverse axes and dims to match narr
238238
xr_axes.reverse()
239239
xr_dims.reverse()

src/imagej/dims.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,12 @@ def _assign_axes(
216216
if cal_axis_type == "DefaultLinearAxis":
217217
origin = xarr.attrs["imagej"][ij_dim + "_origin"]
218218
scale = xarr.attrs["imagej"][ij_dim + "_scale"]
219-
jaxis = metadata.Axis._str_to_cal_axis(cal_axis_type)(
219+
jaxis = metadata.axis.str_to_calibrated_axis(cal_axis_type)(
220220
ax_type, scale, origin
221221
)
222222
else:
223223
try:
224-
jaxis = metadata.Axis._str_to_cal_axis(cal_axis_type)(
224+
jaxis = metadata.axis.str_to_calibrated_axis(cal_axis_type)(
225225
ax_type, doub_coords
226226
)
227227
except (JException, TypeError):

src/imagej/metadata.py

-115
This file was deleted.

src/imagej/metadata/__init__.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from typing import Sequence
2+
3+
import imagej.dims as dims
4+
import imagej.metadata.axis as axis
5+
from imagej._java import jc
6+
7+
8+
def create_imagej_metadata(
9+
axes: Sequence["jc.CalibratedAxis"], dim_seq: Sequence[str]
10+
) -> dict:
11+
"""
12+
Create the ImageJ metadata attribute dictionary for xarray's global attributes.
13+
:param axes: A list or tuple of ImageJ2 axis objects
14+
(e.g. net.imagej.axis.DefaultLinearAxis).
15+
:param dim_seq: A list or tuple of the dimension order (e.g. ['X', 'Y', 'C']).
16+
:return: Dict of image metadata.
17+
"""
18+
ij_metadata = {}
19+
if len(axes) != len(dim_seq):
20+
raise ValueError(
21+
f"Axes length ({len(axes)}) does not match \
22+
dimension length ({len(dim_seq)})."
23+
)
24+
25+
for i in range(len(axes)):
26+
# get CalibratedAxis type as string (e.g. "EnumeratedAxis")
27+
ij_metadata[
28+
dims._to_ijdim(dim_seq[i]) + "_cal_axis_type"
29+
] = axis.calibrated_axis_to_str(axes[i])
30+
# get scale and origin for DefaultLinearAxis
31+
if isinstance(axes[i], jc.DefaultLinearAxis):
32+
ij_metadata[dims._to_ijdim(dim_seq[i]) + "_scale"] = float(axes[i].scale())
33+
ij_metadata[dims._to_ijdim(dim_seq[i]) + "_origin"] = float(
34+
axes[i].origin()
35+
)
36+
37+
return ij_metadata

src/imagej/metadata/axis.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from _jpype import JClass
2+
3+
from imagej._java import jc
4+
5+
_calibrated_axes = [
6+
"net.imagej.axis.ChapmanRichardsAxis",
7+
"net.imagej.axis.DefaultLinearAxis",
8+
"net.imagej.axis.EnumeratedAxis",
9+
"net.imagej.axis.ExponentialAxis",
10+
"net.imagej.axis.ExponentialRecoveryAxis",
11+
"net.imagej.axis.GammaVariateAxis",
12+
"net.imagej.axis.GaussianAxis",
13+
"net.imagej.axis.IdentityAxis",
14+
"net.imagej.axis.InverseRodbardAxis",
15+
"net.imagej.axis.LogLinearAxis",
16+
"net.imagej.axis.PolynomialAxis",
17+
"net.imagej.axis.PowerAxis",
18+
"net.imagej.axis.RodbardAxis",
19+
]
20+
21+
22+
def calibrated_axis_to_str(axis: "jc.CalibratedAxis") -> str:
23+
"""
24+
Convert a CalibratedAxis class to a String.
25+
:param axis: CalibratedAxis type (e.g. net.imagej.axis.DefaultLinearAxis).
26+
:return: String of CalibratedAxis typeb(e.g. "DefaultLinearAxis").
27+
"""
28+
if not isinstance(axis, JClass):
29+
axis = axis.__class__
30+
31+
return str(axis).split("'")[1]
32+
33+
34+
def str_to_calibrated_axis(axis: str) -> "jc.CalibratedAxis":
35+
"""
36+
Convert a String to CalibratedAxis class.
37+
:param axis: String of calibratedAxis type (e.g. "DefaultLinearAxis").
38+
:return: Java class of CalibratedAxis type
39+
(e.g. net.imagej.axis.DefaultLinearAxis).
40+
"""
41+
if not isinstance(axis, str):
42+
raise TypeError(f"Axis {type(axis)} is not a String.")
43+
44+
if axis in _calibrated_axes:
45+
return getattr(jc, axis.split(".")[3])
46+
else:
47+
return None

0 commit comments

Comments
 (0)