Skip to content

Commit

Permalink
Refactor metadata module to be more pythonic
Browse files Browse the repository at this point in the history
The axis submodule has also been streamlined to drop
the CalibratedAxis dict and instead uses a list of Strings.
This avoids Java import errors if a user imports the axis submodule
before initializing ImageJ.
  • Loading branch information
elevans committed Apr 14, 2023
1 parent 61a814c commit 56ca272
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 118 deletions.
2 changes: 1 addition & 1 deletion src/imagej/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def java_to_xarray(ij: "jc.ImageJ", jobj) -> xr.DataArray:
xr_dims = list(permuted_rai.dims)
xr_attrs = sj.to_python(permuted_rai.getProperties())
xr_attrs = {sj.to_python(k): sj.to_python(v) for k, v in xr_attrs.items()}
xr_attrs["imagej"] = metadata.ImageMetadata.create_imagej_metadata(xr_axes, xr_dims)
xr_attrs["imagej"] = metadata.create_imagej_metadata(xr_axes, xr_dims)
# reverse axes and dims to match narr
xr_axes.reverse()
xr_dims.reverse()
Expand Down
4 changes: 2 additions & 2 deletions src/imagej/dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ def _assign_axes(
if cal_axis_type == "DefaultLinearAxis":
origin = xarr.attrs["imagej"][ij_dim + "_origin"]
scale = xarr.attrs["imagej"][ij_dim + "_scale"]
jaxis = metadata.Axis._str_to_cal_axis(cal_axis_type)(
jaxis = metadata.axis.str_to_calibrated_axis(cal_axis_type)(
ax_type, scale, origin
)
else:
try:
jaxis = metadata.Axis._str_to_cal_axis(cal_axis_type)(
jaxis = metadata.axis.str_to_calibrated_axis(cal_axis_type)(
ax_type, doub_coords
)
except (JException, TypeError):
Expand Down
115 changes: 0 additions & 115 deletions src/imagej/metadata.py

This file was deleted.

37 changes: 37 additions & 0 deletions src/imagej/metadata/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Sequence

import imagej.dims as dims
import imagej.metadata.axis as axis
from imagej._java import jc


def create_imagej_metadata(
axes: Sequence["jc.CalibratedAxis"], dim_seq: Sequence[str]
) -> dict:
"""
Create the ImageJ metadata attribute dictionary for xarray's global attributes.
:param axes: A list or tuple of ImageJ2 axis objects
(e.g. net.imagej.axis.DefaultLinearAxis).
:param dim_seq: A list or tuple of the dimension order (e.g. ['X', 'Y', 'C']).
:return: Dict of image metadata.
"""
ij_metadata = {}
if len(axes) != len(dim_seq):
raise ValueError(
f"Axes length ({len(axes)}) does not match \
dimension length ({len(dim_seq)})."
)

for i in range(len(axes)):
# get CalibratedAxis type as string (e.g. "EnumeratedAxis")
ij_metadata[
dims._to_ijdim(dim_seq[i]) + "_cal_axis_type"
] = axis.calibrated_axis_to_str(axes[i])
# get scale and origin for DefaultLinearAxis
if isinstance(axes[i], jc.DefaultLinearAxis):
ij_metadata[dims._to_ijdim(dim_seq[i]) + "_scale"] = float(axes[i].scale())
ij_metadata[dims._to_ijdim(dim_seq[i]) + "_origin"] = float(
axes[i].origin()
)

return ij_metadata
47 changes: 47 additions & 0 deletions src/imagej/metadata/axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from _jpype import JClass

from imagej._java import jc

_calibrated_axes = [
"net.imagej.axis.ChapmanRichardsAxis",
"net.imagej.axis.DefaultLinearAxis",
"net.imagej.axis.EnumeratedAxis",
"net.imagej.axis.ExponentialAxis",
"net.imagej.axis.ExponentialRecoveryAxis",
"net.imagej.axis.GammaVariateAxis",
"net.imagej.axis.GaussianAxis",
"net.imagej.axis.IdentityAxis",
"net.imagej.axis.InverseRodbardAxis",
"net.imagej.axis.LogLinearAxis",
"net.imagej.axis.PolynomialAxis",
"net.imagej.axis.PowerAxis",
"net.imagej.axis.RodbardAxis",
]


def calibrated_axis_to_str(axis: "jc.CalibratedAxis") -> str:
"""
Convert a CalibratedAxis class to a String.
:param axis: CalibratedAxis type (e.g. net.imagej.axis.DefaultLinearAxis).
:return: String of CalibratedAxis typeb(e.g. "DefaultLinearAxis").
"""
if not isinstance(axis, JClass):
axis = axis.__class__

return str(axis).split("'")[1]


def str_to_calibrated_axis(axis: str) -> "jc.CalibratedAxis":
"""
Convert a String to CalibratedAxis class.
:param axis: String of calibratedAxis type (e.g. "DefaultLinearAxis").
:return: Java class of CalibratedAxis type
(e.g. net.imagej.axis.DefaultLinearAxis).
"""
if not isinstance(axis, str):
raise TypeError(f"Axis {type(axis)} is not a String.")

if axis in _calibrated_axes:
return getattr(jc, axis.split(".")[3])
else:
return None

0 comments on commit 56ca272

Please sign in to comment.