Skip to content

Commit ad307b2

Browse files
Merge pull request #556 from michael-0brien/padded-shape
Deprecate `BasicImageConfig(..., pad_options=...)` -> `BasicImageConfig(..., padded_shape=...)`
2 parents a607bdf + 7420c46 commit ad307b2

15 files changed

Lines changed: 78 additions & 103 deletions

benchmarks/benchmark_jit.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,11 @@ def make_particle_parameters(key: PRNGKeyArray, config: cxs.BasicImageConfig):
6767
}
6868

6969
# Generate particle parameters. First, the image config
70-
pad_options = dict(shape=(200, 200))
71-
7270
config = cxs.BasicImageConfig(
7371
shape=(150, 150),
7472
pixel_size=2.0,
7573
voltage_in_kilovolts=300.0,
76-
pad_options=pad_options,
74+
padded_shape=(200, 200),
7775
)
7876
# ... RNG keys
7977
keys = jax.random.split(jax.random.key(0), num_images)

benchmarks/benchmark_projection_method_tradeoff.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ def setup_volumes_and_configs(n_iterations, n_atoms, box_size, pixel_size=2.0):
6666
)
6767

6868
# Image config
69-
pad_options = dict(shape=(int(box_size * 1.5), int(box_size * 1.5)))
69+
padded_shape = (int(box_size * 1.5), int(box_size * 1.5))
7070
config = cxs.BasicImageConfig(
7171
shape=(box_size, box_size),
7272
pixel_size=pixel_size,
7373
voltage_in_kilovolts=300.0,
74-
pad_options=pad_options,
74+
padded_shape=padded_shape,
7575
)
7676

7777
return volume_fourier_grid, avg_time, volume_gmm, atom_volume, config

cryojax/dataset/_particle_data/relion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1148,9 +1148,10 @@ def _make_config(
11481148
voltage_in_kilovolts,
11491149
pad_options,
11501150
):
1151+
padded_shape = None if "shape" not in pad_options else pad_options["shape"]
11511152
return eqx.tree_at(
11521153
lambda x: (x.pixel_size, x.voltage_in_kilovolts),
1153-
BasicImageConfig(image_shape, 1.0, 1.0, pad_options=pad_options),
1154+
BasicImageConfig(image_shape, 1.0, 1.0, padded_shape=padded_shape),
11541155
(pixel_size, voltage_in_kilovolts),
11551156
)
11561157

cryojax/io/_pdb.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any, Literal, TypedDict, overload
1212
from xml.etree import ElementTree
1313

14+
import equinox.internal as eqxi
1415
import jax
1516
import mdtraj
1617
import mmdf
@@ -23,10 +24,10 @@
2324

2425

2526
if hasattr(typing, "GENERATING_DOCUMENTATION"):
26-
AtomProperties = dict # pyright: ignore[reportAssignmentType]
27+
_AtomProperties = dict[str, Any] # pyright: ignore[reportAssignmentType]
2728
else:
2829

29-
class AtomProperties(TypedDict):
30+
class _AtomProperties(TypedDict):
3031
masses: Float[np.ndarray, "... n_atoms"]
3132
b_factors: Float[np.ndarray, "... n_atoms"]
3233
charges: Float[np.ndarray, "... n_atoms"]
@@ -62,7 +63,7 @@ def read_atoms_from_pdb( # type: ignore
6263
) -> tuple[
6364
Float[np.ndarray, "... n_atoms 3"],
6465
Int[np.ndarray, " n_atoms"],
65-
AtomProperties,
66+
dict[str, Any],
6667
]: ...
6768

6869

@@ -81,6 +82,7 @@ def read_atoms_from_pdb(
8182
) -> tuple[Float[np.ndarray, "... n_atoms 3"], Int[np.ndarray, "... n_atoms"]]: ...
8283

8384

85+
@eqxi.doc_remove_args("loads_b_factors")
8486
def read_atoms_from_pdb(
8587
filename: str | pathlib.Path,
8688
*,
@@ -203,7 +205,7 @@ def mmdf_to_atoms( # type: ignore
203205
) -> tuple[
204206
Float[np.ndarray, "... n_atoms 3"],
205207
Int[np.ndarray, "... n_atoms"],
206-
AtomProperties,
208+
dict[str, Any],
207209
]: ...
208210

209211

@@ -408,7 +410,7 @@ def read_topology_from_pdb(
408410
class _AtomicModelInfo(TypedDict):
409411
positions: Float[np.ndarray, "M N 3"]
410412
numbers: Int[np.ndarray, "M N 3"]
411-
properties: AtomProperties
413+
properties: _AtomProperties
412414

413415

414416
def _load_atom_info(df: pd.DataFrame, model_index: int | None, stack_models: bool):
@@ -450,7 +452,7 @@ def _load_atom_info(df: pd.DataFrame, model_index: int | None, stack_models: boo
450452
)
451453

452454
# Gather atom info and return
453-
properties = AtomProperties(
455+
properties = _AtomProperties(
454456
charges=np.asarray(charges, dtype=int),
455457
b_factors=np.asarray(b_factors, dtype=float),
456458
masses=np.asarray(atom_masses, dtype=float),

cryojax/simulator/_image_config.py

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import math
44
import warnings
55
from functools import cached_property
6-
from typing import Any, Literal, TypedDict
6+
from typing import Literal, TypedDict
77

88
import equinox as eqx
9+
import equinox.internal as eqxi
910
import jax
1011
import jax.numpy as jnp
1112
from jaxtyping import Array, Float
@@ -29,8 +30,8 @@
2930

3031
# Not currently public API
3132
class PrecomputedGrids(eqx.Module, strict=True):
32-
only_rfft: bool
33-
only_fourier: bool
33+
only_rfft: bool = eqx.field(static=True)
34+
only_fourier: bool = eqx.field(static=True)
3435

3536
_frequency_grid: Float[Array, "_ _ 2"]
3637
_coordinate_grid: Float[Array, "_ _ 2"] | None
@@ -138,8 +139,7 @@ class AbstractImageConfig(eqx.Module, strict=True):
138139
pixel_size: eqx.AbstractVar[Float[Array, ""]]
139140
voltage_in_kilovolts: eqx.AbstractVar[Float[Array, ""]]
140141

141-
pad_options: eqx.AbstractVar[dict[str, Any]]
142-
142+
padded_shape: eqx.AbstractVar[tuple[int, int]]
143143
precompute_mode: eqx.AbstractVar[
144144
Literal["none", "rfft", "fft", "all", "compile_time_eval"]
145145
]
@@ -160,12 +160,6 @@ def __check_init__(self):
160160
f"Found that `{cls}.padded_shape` is less than `{cls}.shape` in one or "
161161
" more dimensions."
162162
)
163-
if set(self.pad_options.keys()) != {"shape"}:
164-
raise AttributeError(
165-
f"Found that `{cls}.pad_options` was not a dictionary with "
166-
"key 'shape'. "
167-
f"Instead, had keys {set(self.pad_options.keys())}."
168-
)
169163

170164
@property
171165
def wavelength_in_angstroms(self) -> Float[Array, ""]:
@@ -278,10 +272,6 @@ def _get_grid_impl(_self):
278272

279273
return frequency_grid
280274

281-
@property
282-
def padded_shape(self):
283-
return self.pad_options["shape"]
284-
285275
@property
286276
def n_pixels(self) -> int:
287277
"""Convenience property for `math.prod(shape)`"""
@@ -544,21 +534,24 @@ class BasicImageConfig(AbstractImageConfig, strict=True):
544534
pixel_size: Float[Array, ""]
545535
voltage_in_kilovolts: Float[Array, ""]
546536

547-
pad_options: PadOptions
548-
549-
precompute_mode: Literal["none", "rfft", "fft", "all", "compile_time_eval"]
537+
padded_shape: tuple[int, int]
538+
precompute_mode: Literal["none", "rfft", "fft", "all", "compile_time_eval"] = (
539+
eqx.field(static=True)
540+
)
550541
precomputed_grids: PrecomputedGrids | None
551542

543+
@eqxi.doc_remove_args("pad_options")
552544
def __init__(
553545
self,
554546
shape: tuple[int, int],
555547
pixel_size: FloatLike,
556548
voltage_in_kilovolts: FloatLike,
557549
*,
550+
padded_shape: tuple[int, int] | None = None,
558551
precompute_mode: Literal[
559552
"none", "rfft", "fft", "all", "compile_time_eval"
560553
] = "none",
561-
pad_options: dict[str, Any] = {},
554+
pad_options: dict = {},
562555
):
563556
"""**Arguments:**
564557
@@ -568,11 +561,9 @@ def __init__(
568561
The pixel size of the image in angstroms.
569562
- `voltage_in_kilovolts`:
570563
The incident energy of the electron beam.
571-
- `pad_options`:
572-
Options that control image padding.
573-
- 'shape':
574-
The shape of the image after padding. By default, equal
575-
to `shape`.
564+
- `padded_shape`:
565+
The shape of the image after padding. By default, equal
566+
to `shape`.
576567
- `precompute_mode`:
577568
How to pre-compute coordinate and frequency grids stored in
578569
the `image_config`. Options are
@@ -597,10 +588,17 @@ def __init__(
597588
# Set parameters
598589
self.pixel_size = jnp.asarray(pixel_size, dtype=float)
599590
self.voltage_in_kilovolts = jnp.asarray(voltage_in_kilovolts, dtype=float)
600-
# Set shape
591+
# Set shape and padded shape
592+
if "shape" in pad_options:
593+
warnings.warn(
594+
"`BasicImageConfig(..., pad_options=...)` is deprecated and will "
595+
"be removed in cryoJAX 0.6.0. Use `padded_shape` instead.",
596+
category=FutureWarning,
597+
stacklevel=2,
598+
)
599+
padded_shape = pad_options["shape"]
601600
self.shape = shape
602-
# Set pad options
603-
self.pad_options = _dict_to_pad_options(pad_options, shape)
601+
self.padded_shape = shape if padded_shape is None else padded_shape
604602
# Finally, grid precompute
605603
if precompute_mode == "rfft":
606604
self.precomputed_grids = PrecomputedGrids(
@@ -628,22 +626,25 @@ class DoseImageConfig(AbstractImageConfig, strict=True):
628626
voltage_in_kilovolts: Float[Array, ""]
629627
electron_dose: Float[Array, ""]
630628

631-
pad_options: PadOptions
632-
633-
precompute_mode: Literal["none", "rfft", "fft", "all", "compile_time_eval"]
629+
padded_shape: tuple[int, int]
630+
precompute_mode: Literal["none", "rfft", "fft", "all", "compile_time_eval"] = (
631+
eqx.field(static=True)
632+
)
634633
precomputed_grids: PrecomputedGrids | None
635634

635+
@eqxi.doc_remove_args("pad_options")
636636
def __init__(
637637
self,
638638
shape: tuple[int, int],
639639
pixel_size: FloatLike,
640640
voltage_in_kilovolts: FloatLike,
641641
electron_dose: FloatLike,
642642
*,
643+
padded_shape: tuple[int, int] | None = None,
643644
precompute_mode: Literal[
644645
"none", "rfft", "fft", "all", "compile_time_eval"
645646
] = "none",
646-
pad_options: dict[str, Any] = {},
647+
pad_options: dict = {},
647648
):
648649
"""**Arguments:**
649650
@@ -656,12 +657,9 @@ def __init__(
656657
- `electron_dose`:
657658
The integrated dose rate of the electron beam in
658659
$e^-/A^2$
659-
- `pad_options`:
660-
Options that control image padding. This is a dictionary
661-
with keys:
662-
- 'shape':
663-
The shape of the image after padding. By default, equal
664-
to `shape`.
660+
- `padded_shape`:
661+
The shape of the image after padding. By default, equal
662+
to `shape`.
665663
- `precompute_mode`:
666664
How to pre-compute coordinate and frequency grids stored in
667665
the `image_config`. Options are
@@ -687,10 +685,17 @@ def __init__(
687685
self.pixel_size = jnp.asarray(pixel_size, dtype=float)
688686
self.voltage_in_kilovolts = jnp.asarray(voltage_in_kilovolts, dtype=float)
689687
self.electron_dose = jnp.asarray(electron_dose, dtype=float)
690-
# Set shape
688+
# Set shape and padded shape
689+
if "shape" in pad_options:
690+
warnings.warn(
691+
"`BasicImageConfig(..., pad_options=...)` is deprecated and will "
692+
"be removed in cryoJAX 0.6.0. Use `padded_shape` instead.",
693+
category=FutureWarning,
694+
stacklevel=2,
695+
)
696+
padded_shape = pad_options["shape"]
691697
self.shape = shape
692-
# Set pad options
693-
self.pad_options = _dict_to_pad_options(pad_options, shape)
698+
self.padded_shape = shape if padded_shape is None else padded_shape
694699
# Finally, grid precompute
695700
if precompute_mode == "rfft":
696701
self.precomputed_grids = PrecomputedGrids(
@@ -728,16 +733,3 @@ def _safe_multiply_by_constant(
728733
grid = grid.at[:, 1:, 0].multiply(constant)
729734
grid = grid.at[1:, :, 1].multiply(constant)
730735
return grid
731-
732-
733-
def _dict_to_pad_options(d: dict[str, Any], default_shape: tuple[int, int]) -> PadOptions:
734-
if not set(d.keys()).issubset({"shape"}):
735-
raise ValueError(
736-
"Expected that dictionary `pad_options` passed to "
737-
"`BasicImageConfig(..., pad_options=...)` "
738-
f"had a subset of keys {{'shape'}}, but found that it had keys "
739-
f"{set(d.keys())}."
740-
)
741-
shape = d["shape"] if "shape" in d else default_shape
742-
743-
return PadOptions(shape=shape)

docs/api/simulator/config.md

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,14 @@ The `AbstractImageConfig` is an object at the core of simulating images in `cryo
88
- shape
99
- pixel_size
1010
- voltage_in_kilovolts
11-
- pad_options
11+
- padded_shape
1212

1313
---
1414

1515
::: cryojax.simulator.BasicImageConfig
1616
options:
1717
members:
1818
- __init__
19-
- shape
20-
- pixel_size
21-
- voltage_in_kilovolts
2219
- wavelength_in_angstroms
2320
- lorentz_factor
2421
- interaction_constant
@@ -27,7 +24,6 @@ The `AbstractImageConfig` is an object at the core of simulating images in `cryo
2724
- n_pixels
2825
- y_dim
2926
- x_dim
30-
- padded_shape
3127
- padded_n_pixels
3228
- padded_y_dim
3329
- padded_x_dim
@@ -38,10 +34,6 @@ The `AbstractImageConfig` is an object at the core of simulating images in `cryo
3834
options:
3935
members:
4036
- __init__
41-
- shape
42-
- pixel_size
43-
- voltage_in_kilovolts
44-
- electron_dose
4537
- wavelength_in_angstroms
4638
- lorentz_factor
4739
- interaction_constant
@@ -50,7 +42,6 @@ The `AbstractImageConfig` is an object at the core of simulating images in `cryo
5042
- n_pixels
5143
- y_dim
5244
- x_dim
53-
- padded_shape
5445
- padded_n_pixels
5546
- padded_y_dim
5647
- padded_x_dim

docs/examples/extending-cryojax.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,11 @@
8686
" psi_angle=-5.0,\n",
8787
")\n",
8888
"# ... the configuration. Add padding with respect to the final image shape.\n",
89-
"pad_options = dict(shape=volume.shape[0:2])\n",
9089
"config = cxs.BasicImageConfig(\n",
9190
" shape=(80, 80),\n",
9291
" pixel_size=voxel_size,\n",
9392
" voltage_in_kilovolts=300.0,\n",
94-
" pad_options=pad_options,\n",
93+
" padded_shape=volume.shape[0:2],\n",
9594
")\n",
9695
"\n",
9796
"\n",

docs/examples/simulate-image.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
},
8181
{
8282
"cell_type": "code",
83-
"execution_count": 5,
83+
"execution_count": null,
8484
"metadata": {},
8585
"outputs": [],
8686
"source": [
@@ -100,12 +100,11 @@
100100
")\n",
101101
"transfer_theory = cxs.ContrastTransferTheory(ctf, amplitude_contrast_ratio=0.1)\n",
102102
"# Then the configuration. Add padding to avoid periodic artifacts due to CTF rings\n",
103-
"pad_options = dict(shape=volume.shape[0:2])\n",
104103
"image_config = cxs.BasicImageConfig(\n",
105104
" shape=(80, 80),\n",
106105
" pixel_size=voxel_size,\n",
107106
" voltage_in_kilovolts=300.0,\n",
108-
" pad_options=pad_options,\n",
107+
" padded_shape=volume.shape[0:2],\n",
109108
")"
110109
]
111110
},

0 commit comments

Comments
 (0)