33import math
44import warnings
55from functools import cached_property
6- from typing import Any , Literal , TypedDict
6+ from typing import Literal , TypedDict
77
88import equinox as eqx
9+ import equinox .internal as eqxi
910import jax
1011import jax .numpy as jnp
1112from jaxtyping import Array , Float
2930
3031# Not currently public API
3132class PrecomputedGrids (eqx .Module , strict = True ):
32- only_rfft : bool
33- only_fourier : bool
33+ only_rfft : bool = eqx . field ( static = True )
34+ only_fourier : bool = eqx . field ( static = True )
3435
3536 _frequency_grid : Float [Array , "_ _ 2" ]
3637 _coordinate_grid : Float [Array , "_ _ 2" ] | None
@@ -138,8 +139,7 @@ class AbstractImageConfig(eqx.Module, strict=True):
138139 pixel_size : eqx .AbstractVar [Float [Array , "" ]]
139140 voltage_in_kilovolts : eqx .AbstractVar [Float [Array , "" ]]
140141
141- pad_options : eqx .AbstractVar [dict [str , Any ]]
142-
142+ padded_shape : eqx .AbstractVar [tuple [int , int ]]
143143 precompute_mode : eqx .AbstractVar [
144144 Literal ["none" , "rfft" , "fft" , "all" , "compile_time_eval" ]
145145 ]
@@ -160,12 +160,6 @@ def __check_init__(self):
160160 f"Found that `{ cls } .padded_shape` is less than `{ cls } .shape` in one or "
161161 " more dimensions."
162162 )
163- if set (self .pad_options .keys ()) != {"shape" }:
164- raise AttributeError (
165- f"Found that `{ cls } .pad_options` was not a dictionary with "
166- "key 'shape'. "
167- f"Instead, had keys { set (self .pad_options .keys ())} ."
168- )
169163
170164 @property
171165 def wavelength_in_angstroms (self ) -> Float [Array , "" ]:
@@ -278,10 +272,6 @@ def _get_grid_impl(_self):
278272
279273 return frequency_grid
280274
281- @property
282- def padded_shape (self ):
283- return self .pad_options ["shape" ]
284-
285275 @property
286276 def n_pixels (self ) -> int :
287277 """Convenience property for `math.prod(shape)`"""
@@ -544,21 +534,24 @@ class BasicImageConfig(AbstractImageConfig, strict=True):
544534 pixel_size : Float [Array , "" ]
545535 voltage_in_kilovolts : Float [Array , "" ]
546536
547- pad_options : PadOptions
548-
549- precompute_mode : Literal ["none" , "rfft" , "fft" , "all" , "compile_time_eval" ]
537+ padded_shape : tuple [int , int ]
538+ precompute_mode : Literal ["none" , "rfft" , "fft" , "all" , "compile_time_eval" ] = (
539+ eqx .field (static = True )
540+ )
550541 precomputed_grids : PrecomputedGrids | None
551542
543+ @eqxi .doc_remove_args ("pad_options" )
552544 def __init__ (
553545 self ,
554546 shape : tuple [int , int ],
555547 pixel_size : FloatLike ,
556548 voltage_in_kilovolts : FloatLike ,
557549 * ,
550+ padded_shape : tuple [int , int ] | None = None ,
558551 precompute_mode : Literal [
559552 "none" , "rfft" , "fft" , "all" , "compile_time_eval"
560553 ] = "none" ,
561- pad_options : dict [ str , Any ] = {},
554+ pad_options : dict = {},
562555 ):
563556 """**Arguments:**
564557
@@ -568,11 +561,9 @@ def __init__(
568561 The pixel size of the image in angstroms.
569562 - `voltage_in_kilovolts`:
570563 The incident energy of the electron beam.
571- - `pad_options`:
572- Options that control image padding.
573- - 'shape':
574- The shape of the image after padding. By default, equal
575- to `shape`.
564+ - `padded_shape`:
565+ The shape of the image after padding. By default, equal
566+ to `shape`.
576567 - `precompute_mode`:
577568 How to pre-compute coordinate and frequency grids stored in
578569 the `image_config`. Options are
@@ -597,10 +588,17 @@ def __init__(
597588 # Set parameters
598589 self .pixel_size = jnp .asarray (pixel_size , dtype = float )
599590 self .voltage_in_kilovolts = jnp .asarray (voltage_in_kilovolts , dtype = float )
600- # Set shape
591+ # Set shape and padded shape
592+ if "shape" in pad_options :
593+ warnings .warn (
594+ "`BasicImageConfig(..., pad_options=...)` is deprecated and will "
595+ "be removed in cryoJAX 0.6.0. Use `padded_shape` instead." ,
596+ category = FutureWarning ,
597+ stacklevel = 2 ,
598+ )
599+ padded_shape = pad_options ["shape" ]
601600 self .shape = shape
602- # Set pad options
603- self .pad_options = _dict_to_pad_options (pad_options , shape )
601+ self .padded_shape = shape if padded_shape is None else padded_shape
604602 # Finally, grid precompute
605603 if precompute_mode == "rfft" :
606604 self .precomputed_grids = PrecomputedGrids (
@@ -628,22 +626,25 @@ class DoseImageConfig(AbstractImageConfig, strict=True):
628626 voltage_in_kilovolts : Float [Array , "" ]
629627 electron_dose : Float [Array , "" ]
630628
631- pad_options : PadOptions
632-
633- precompute_mode : Literal ["none" , "rfft" , "fft" , "all" , "compile_time_eval" ]
629+ padded_shape : tuple [int , int ]
630+ precompute_mode : Literal ["none" , "rfft" , "fft" , "all" , "compile_time_eval" ] = (
631+ eqx .field (static = True )
632+ )
634633 precomputed_grids : PrecomputedGrids | None
635634
635+ @eqxi .doc_remove_args ("pad_options" )
636636 def __init__ (
637637 self ,
638638 shape : tuple [int , int ],
639639 pixel_size : FloatLike ,
640640 voltage_in_kilovolts : FloatLike ,
641641 electron_dose : FloatLike ,
642642 * ,
643+ padded_shape : tuple [int , int ] | None = None ,
643644 precompute_mode : Literal [
644645 "none" , "rfft" , "fft" , "all" , "compile_time_eval"
645646 ] = "none" ,
646- pad_options : dict [ str , Any ] = {},
647+ pad_options : dict = {},
647648 ):
648649 """**Arguments:**
649650
@@ -656,12 +657,9 @@ def __init__(
656657 - `electron_dose`:
657658 The integrated dose rate of the electron beam in
658659 $e^-/A^2$
659- - `pad_options`:
660- Options that control image padding. This is a dictionary
661- with keys:
662- - 'shape':
663- The shape of the image after padding. By default, equal
664- to `shape`.
660+ - `padded_shape`:
661+ The shape of the image after padding. By default, equal
662+ to `shape`.
665663 - `precompute_mode`:
666664 How to pre-compute coordinate and frequency grids stored in
667665 the `image_config`. Options are
@@ -687,10 +685,17 @@ def __init__(
687685 self .pixel_size = jnp .asarray (pixel_size , dtype = float )
688686 self .voltage_in_kilovolts = jnp .asarray (voltage_in_kilovolts , dtype = float )
689687 self .electron_dose = jnp .asarray (electron_dose , dtype = float )
690- # Set shape
688+ # Set shape and padded shape
689+ if "shape" in pad_options :
690+ warnings .warn (
691+ "`BasicImageConfig(..., pad_options=...)` is deprecated and will "
692+ "be removed in cryoJAX 0.6.0. Use `padded_shape` instead." ,
693+ category = FutureWarning ,
694+ stacklevel = 2 ,
695+ )
696+ padded_shape = pad_options ["shape" ]
691697 self .shape = shape
692- # Set pad options
693- self .pad_options = _dict_to_pad_options (pad_options , shape )
698+ self .padded_shape = shape if padded_shape is None else padded_shape
694699 # Finally, grid precompute
695700 if precompute_mode == "rfft" :
696701 self .precomputed_grids = PrecomputedGrids (
@@ -728,16 +733,3 @@ def _safe_multiply_by_constant(
728733 grid = grid .at [:, 1 :, 0 ].multiply (constant )
729734 grid = grid .at [1 :, :, 1 ].multiply (constant )
730735 return grid
731-
732-
733- def _dict_to_pad_options (d : dict [str , Any ], default_shape : tuple [int , int ]) -> PadOptions :
734- if not set (d .keys ()).issubset ({"shape" }):
735- raise ValueError (
736- "Expected that dictionary `pad_options` passed to "
737- "`BasicImageConfig(..., pad_options=...)` "
738- f"had a subset of keys {{'shape'}}, but found that it had keys "
739- f"{ set (d .keys ())} ."
740- )
741- shape = d ["shape" ] if "shape" in d else default_shape
742-
743- return PadOptions (shape = shape )
0 commit comments