Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions kfac_jax/_src/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""K-FAC tracing functionality for functions needed for curvature estimation."""
from collections.abc import Iterable
import dataclasses
import functools
from typing import Any, Callable, Sequence, TypeVar, Generic
import itertools
from typing import Any, Callable, Generic, Self, Sequence, TypeVar

from absl import logging
import jax
Expand Down Expand Up @@ -739,6 +741,41 @@ def losses_sum(param_primals: Params) -> Array:
return hvp, processed_jaxpr.reconstruct_losses(losses_inputs)


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class VarMap:
"""A mapping from jaxpr variables to values.

Variables in a jaxpr are not ordered, and thus ``dict[Var, ...]`` cannot be
passed to PyTree APIs. This class works around that by indexing the dict
on the ID of each variable instead of the variable itself.
"""
id_to_var: dict[int, Var] = dataclasses.field(
default_factory=dict, metadata=dict(static=True)
)
var_to_val: dict[int, Any] = dataclasses.field(default_factory=dict)

def __contains__(self, var: Var) -> bool:
return id(var) in self.id_to_var

def __getitem__(self, var: Var) -> Any:
return self.var_to_val[id(var)]

def get(self, var: Var) -> Any:
return self.var_to_val[id(var)]

def update(self, it: Iterable[tuple[Var, Any]]) -> None:
for var, val in it:
self.id_to_var[id(var)] = var
self.var_to_val[id(var)] = val

@classmethod
def create(cls, it: Iterable[tuple[Var, Any]]) -> Self:
self = cls()
self.update(it)
return self


def _layer_tag_vjp(
processed_jaxpr: ProcessedJaxpr,
primal_func_args: FuncArgs,
Expand Down Expand Up @@ -867,21 +904,24 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:

# First compute the primal values for the inputs to all layer tags
layer_input_values = forward()
primals_dict = dict(zip(layer_input_vars, layer_input_values))
primals = zip(layer_input_vars, layer_input_values)

# Update with the values of all parameters, which are inputs to the function
primals_dict.update(
primals = itertools.chain(
primals,
zip(
processed_jaxpr.jaxpr.invars,
jax.tree_util.tree_leaves(primal_func_args),
)
),
)

primals_dict = VarMap.create(primals)

# Create auxiliary values all equal to zero.
aux_values = jax.tree_util.tree_map(jnp.zeros_like, layer_input_values)

# Create a mapping from all layer tag inputs to the zero values
aux_dict = dict(zip(layer_input_vars, aux_values))
aux_dict = VarMap.create(zip(layer_input_vars, aux_values))

# These values would now allow us to compute gradients wrt the layer tags
# inputs, which are intermediate expressions in the Jaxpr.
Expand Down