Skip to content

Commit a7f8af9

Browse files
committed
improve hijax guide
1 parent 697f4e5 commit a7f8af9

File tree

6 files changed

+485
-216
lines changed

6 files changed

+485
-216
lines changed

docs_nnx/hijax/hijax.ipynb

Lines changed: 90 additions & 21 deletions
Large diffs are not rendered by default.

docs_nnx/hijax/hijax.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,40 @@ jupytext:
88
jupytext_version: 1.13.8
99
---
1010

11-
# Hijax Variable
11+
# Hijax
1212

1313
```{code-cell} ipython3
1414
from flax import nnx
1515
import jax
1616
import jax.numpy as jnp
1717
import optax
1818
19-
current_mode = nnx.using_hijax()
19+
current_mode = nnx.using_hijax() # ignore: only needed for testing
2020
```
2121

22+
```{code-cell} ipython3
23+
nnx.use_hijax(True)
24+
25+
rngs = nnx.Rngs(0)
26+
model = nnx.Linear(2, 3, rngs=rngs)
27+
optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)
28+
29+
@jax.jit
30+
def train_step(x, y):
31+
loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
32+
loss, grads = jax.value_and_grad(loss_fn)(model) # tmp fix for jax.grad
33+
optimizer.update(model, grads)
34+
return loss
35+
36+
x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3))
37+
for _ in range(3):
38+
print(train_step(x, y))
39+
```
40+
41+
## Hijax Variable
42+
43+
+++
44+
2245
State propagation:
2346

2447
```{code-cell} ipython3

flax/nnx/graph.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from flax.nnx.statelib import FlatState, State, map_state
3434
from flax.nnx.variablelib import Variable, is_array_ref, V
35-
from flax.typing import Key, PathParts, is_key_like
35+
from flax.typing import HashableMapping, Key, PathParts, is_key_like
3636
import jax
3737
import numpy as np
3838
import treescope # type: ignore[import-not-found,import-untyped]
@@ -301,50 +301,6 @@ def get_node_impl_for_type(
301301
return None
302302

303303

304-
class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
305-
_mapping: dict[HA, HB] | tp.Mapping[HA, HB]
306-
307-
def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
308-
self._mapping = dict(mapping) if copy else mapping
309-
310-
def __contains__(self, key: object) -> bool:
311-
return key in self._mapping
312-
313-
def __getitem__(self, key: HA) -> HB:
314-
return self._mapping[key]
315-
316-
def __iter__(self) -> tp.Iterator[HA]:
317-
return iter(self._mapping)
318-
319-
def __len__(self) -> int:
320-
return len(self._mapping)
321-
322-
def __hash__(self) -> int:
323-
# use type-aware sorting to support int keys
324-
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
325-
key, _ = item
326-
if isinstance(key, int):
327-
return (0, key)
328-
elif isinstance(key, str):
329-
return (1, key)
330-
else:
331-
raise ValueError(f'Unsupported key type: {type(key)!r}')
332-
return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn)))
333-
334-
def __eq__(self, other: tp.Any) -> bool:
335-
return (
336-
isinstance(other, HashableMapping) and self._mapping == other._mapping
337-
)
338-
339-
def __repr__(self) -> str:
340-
return repr(self._mapping)
341-
342-
def update(self, other: tp.Mapping[HA, HB]) -> HashableMapping[HA, HB]:
343-
"""Updates the mapping with another mapping."""
344-
mapping = dict(self._mapping)
345-
mapping.update(other)
346-
return HashableMapping(mapping, copy=False)
347-
348304

349305
@jax.tree_util.register_static
350306
@dataclasses.dataclass(frozen=True, repr=False)

0 commit comments

Comments
 (0)