Skip to content

Commit

Permalink
Allow StateIndex to be passed dynamically
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 13, 2024
1 parent 97ac55a commit 59d21d9
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions equinox/nn/_stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@
_T = TypeVar("_T")


class _Sentinel(Module):
"""A module for sentinels that can be passed dynamically."""

pass


# Used as a sentinel in two ways: keeping track of updated `State`s, and keeping track
# of deleted initial states.
_sentinel = _Sentinel()


class StateIndex(Module, Generic[_Value], strict=True):
"""This wraps together (a) a unique dictionary key used for looking up a stateful
value, and (b) how that stateful value should be initialised.
Expand All @@ -43,10 +54,10 @@ def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
[`equinox.nn.BatchNorm`][] for further reference.
""" # noqa: E501

# Starts off as an `object` when initialised; later replaced with an `int` inside
# Starts off as None when initialised; later replaced with an `int` inside
# `make_with_state`.
marker: Union[object, int] = field(static=True)
init: _Value
init: Union[_Value, _Sentinel]

def __init__(self, init: _Value):
"""**Arguments:**
Expand All @@ -70,11 +81,6 @@ def _is_index(x: Any) -> bool:
return isinstance(x, StateIndex)


# Used as a sentinel in two ways: keeping track of updated `State`s, and keeping track
# of deleted initial states.
_sentinel = object()


_state_error = """
Attempted to use old state. Probably you have done something like:
```
Expand Down Expand Up @@ -117,14 +123,14 @@ def __init__(self, model: PyTree):
leaves = jtu.tree_leaves(model, is_leaf=_is_index)
for leaf in leaves:
if _is_index(leaf):
if leaf.init is _sentinel:
if isinstance(leaf.init, _Sentinel):
raise ValueError(
"Cannot call `eqx.nn.State(eqx.nn.delete_init_state(model))`. "
"You should call `eqx.nn.State(model)`, using the original "
"model."
)
state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init)
self._state = state
self._state: Union[_Sentinel, dict[object | int, Any]] = state

def get(self, item: StateIndex[_Value]) -> _Value:
"""Given an [`equinox.nn.StateIndex`][], returns the value of its state.
Expand All @@ -137,11 +143,11 @@ def get(self, item: StateIndex[_Value]) -> _Value:
The current state associated with that index.
"""
if self._state is _sentinel:
if isinstance(self._state, _Sentinel):
raise ValueError(_state_error)
if type(item) is not StateIndex:
raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.")
return self._state[item.marker] # pyright: ignore
return self._state[item.marker]

def set(self, item: StateIndex[_Value], value: _Value) -> "State":
"""Sets a new value for an [`equinox.nn.StateIndex`][], **and returns the
Expand All @@ -159,11 +165,11 @@ def set(self, item: StateIndex[_Value], value: _Value) -> "State":
As a safety guard against accidentally writing `state.set(item, value)` without
assigning it to a new value, then the old object (`self`) will become invalid.
"""
if self._state is _sentinel:
if isinstance(self._state, _Sentinel):
raise ValueError(_state_error)
if type(item) is not StateIndex:
raise ValueError("Can only use `eqx.nn.StateIndex`s as state keys.")
old_value = self._state[item.marker] # pyright: ignore
old_value = self._state[item.marker]
value = jtu.tree_map(jnp.asarray, value)
old_struct = jax.eval_shape(lambda: old_value)
new_struct = jax.eval_shape(lambda: value)
Expand Down Expand Up @@ -195,7 +201,7 @@ def substate(self, pytree: PyTree) -> "State":
A new [`equinox.nn.State`][] object, which tracks only some of the overall
states.
"""
if self._state is _sentinel:
if isinstance(self._state, _Sentinel):
raise ValueError(_state_error)
leaves = jtu.tree_leaves(pytree, is_leaf=_is_index)
markers = [x.marker for x in leaves if _is_index(x)]
Expand All @@ -219,7 +225,7 @@ def update(self, substate: "State") -> "State":
As a safety guard against accidentally writing `state.set(item, value)` without
assigning it to a new value, then the old object (`self`) will become invalid.
"""
if self._state is _sentinel:
if isinstance(self._state, _Sentinel):
raise ValueError(_state_error)
if type(substate) is not State:
raise ValueError("Can only use `eqx.nn.State`s in `update`.")
Expand All @@ -240,7 +246,7 @@ def __repr__(self):
return tree_pformat(self)

def __tree_pp__(self, **kwargs):
if self._state is _sentinel:
if isinstance(self._state, _Sentinel):
return text("State(~old~)")
else:
objs = named_objs(
Expand All @@ -259,7 +265,7 @@ def __tree_pp__(self, **kwargs):
)

def tree_flatten(self):
if self._state is _sentinel:
if isinstance(self._state, _Sentinel):
raise ValueError(_state_error)
keys = tuple(self._state.keys()) # pyright: ignore
values = tuple(self._state[k] for k in keys) # pyright: ignore
Expand Down

0 comments on commit 59d21d9

Please sign in to comment.