diff --git a/kfac_jax/_src/tracer.py b/kfac_jax/_src/tracer.py
index 776dbf9..eeb7e22 100644
--- a/kfac_jax/_src/tracer.py
+++ b/kfac_jax/_src/tracer.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """K-FAC tracing functionality for functions needed for curvature estimation."""
+from collections.abc import Iterable
 import dataclasses
 import functools
-from typing import Any, Callable, Sequence, TypeVar, Generic
+import itertools
+from typing import Any, Callable, Generic, Self, Sequence, TypeVar
 
 from absl import logging
 import jax
@@ -739,6 +741,53 @@ def losses_sum(param_primals: Params) -> Array:
   return hvp, processed_jaxpr.reconstruct_losses(losses_inputs)
 
 
+@jax.tree_util.register_dataclass
+@dataclasses.dataclass(frozen=True)
+class VarMap:
+  """A mapping from jaxpr variables to values.
+
+  Variables in a jaxpr are not ordered, and thus ``dict[Var, ...]`` cannot be
+  passed to PyTree APIs. This class works around that by indexing the dict
+  on the ID of each variable instead of the variable itself.
+
+  ``id_to_var`` is marked as static metadata and holds a reference to every
+  variable; ``var_to_val`` holds the values, keyed by the same IDs. The
+  dataclass is frozen, but both dicts are mutated in place by ``update``.
+  """
+  # NOTE(review): JAX expects static fields to be hashable; a dict is not.
+  # Confirm that instances never need to be hashed (e.g. as jit arguments).
+  id_to_var: dict[int, Var] = dataclasses.field(
+      default_factory=dict, metadata=dict(static=True)
+  )
+  var_to_val: dict[int, Any] = dataclasses.field(default_factory=dict)
+
+  def __contains__(self, var: Var) -> bool:
+    """Returns whether ``var`` has been added to the map."""
+    return id(var) in self.id_to_var
+
+  def __getitem__(self, var: Var) -> Any:
+    """Returns the value stored for ``var``; raises ``KeyError`` if absent."""
+    return self.var_to_val[id(var)]
+
+  def get(self, var: Var, default: Any = None) -> Any:
+    """Returns the value stored for ``var``, or ``default`` if absent."""
+    return self.var_to_val.get(id(var), default)
+
+  def update(self, it: Iterable[tuple[Var, Any]]) -> None:
+    """Adds or overwrites entries from an iterable of ``(var, value)``."""
+    for var, val in it:
+      # Storing ``var`` keeps it alive, so its ``id`` is never reused.
+      self.id_to_var[id(var)] = var
+      self.var_to_val[id(var)] = val
+
+  @classmethod
+  def create(cls, it: Iterable[tuple[Var, Any]]) -> Self:
+    """Builds a new map populated from an iterable of ``(var, value)``."""
+    var_map = cls()
+    var_map.update(it)
+    return var_map
+
+
 def _layer_tag_vjp(
     processed_jaxpr: ProcessedJaxpr,
     primal_func_args: FuncArgs,
@@ -867,21 +916,24 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
 
   # First compute the primal values for the inputs to all layer tags
   layer_input_values = forward()
-  primals_dict = dict(zip(layer_input_vars, layer_input_values))
+  primals = zip(layer_input_vars, layer_input_values)
 
   # Update with the values of all parameters, which are inputs to the function
-  primals_dict.update(
+  primals = itertools.chain(
+      primals,
       zip(
           processed_jaxpr.jaxpr.invars,
           jax.tree_util.tree_leaves(primal_func_args),
-      )
+      ),
   )
 
+  primals_dict = VarMap.create(primals)
+
   # Create auxiliary values all equal to zero.
   aux_values = jax.tree_util.tree_map(jnp.zeros_like, layer_input_values)
 
   # Create a mapping from all layer tag inputs to the zero values
-  aux_dict = dict(zip(layer_input_vars, aux_values))
+  aux_dict = VarMap.create(zip(layer_input_vars, aux_values))
 
   # These values would now allow us to compute gradients wrt the layer tags
   # inputs, which are intermediate expressions in the Jaxpr.