Skip to content

Commit

Permalink
Make StateIndex a PyTree
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 11, 2024
1 parent 97ac55a commit 6ed3c57
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions equinox/nn/_stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
# Starts off as an `object` when initialised; later replaced with an `int` inside
# `make_with_state`.
marker: Union[object, int] = field(static=True)
init: _Value
init: Union[tuple[_Value], tuple[()]]

def __init__(self, init: _Value):
"""**Arguments:**
Expand All @@ -63,7 +63,7 @@ def __init__(self, init: _Value):
"initial state.)"
)
self.marker = object()
self.init = init
self.init = (init,)


def _is_index(x: Any) -> bool:
Expand Down Expand Up @@ -117,13 +117,13 @@ def __init__(self, model: PyTree):
leaves = jtu.tree_leaves(model, is_leaf=_is_index)
for leaf in leaves:
if _is_index(leaf):
if leaf.init is _sentinel:
if leaf.init == ():
raise ValueError(
"Cannot call `eqx.nn.State(eqx.nn.delete_init_state(model))`. "
"You should call `eqx.nn.State(model)`, using the original "
"model."
)
state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init)
state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init[0])
self._state = state

def get(self, item: StateIndex[_Value]) -> _Value:
Expand Down Expand Up @@ -278,7 +278,7 @@ def tree_unflatten(cls, keys, values):

def _delete_init_state(x):
if _is_index(x):
return tree_at(lambda y: y.init, x, _sentinel)
return tree_at(lambda y: y.init, x, ())
else:
return x

Expand Down Expand Up @@ -372,7 +372,7 @@ def make_with_state_impl(*args, **kwargs) -> tuple[_T, State]:
new_leaves = []
for leaf in leaves:
if _is_index(leaf):
leaf = StateIndex(leaf.init)
leaf = StateIndex(leaf.init[0])
object.__setattr__(leaf, "marker", counter)
counter += 1
new_leaves.append(leaf)
Expand Down

0 comments on commit 6ed3c57

Please sign in to comment.