|
32 | 32 | ) |
33 | 33 | from flax.nnx.statelib import FlatState, State, map_state |
34 | 34 | from flax.nnx.variablelib import Variable, is_array_ref, V |
35 | | -from flax.typing import Key, PathParts, is_key_like |
| 35 | +from flax.typing import HashableMapping, Key, PathParts, is_key_like |
36 | 36 | import jax |
37 | 37 | import numpy as np |
38 | 38 | import treescope # type: ignore[import-not-found,import-untyped] |
@@ -301,50 +301,6 @@ def get_node_impl_for_type( |
301 | 301 | return None |
302 | 302 |
|
303 | 303 |
|
304 | | -class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): |
305 | | - _mapping: dict[HA, HB] | tp.Mapping[HA, HB] |
306 | | - |
307 | | - def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True): |
308 | | - self._mapping = dict(mapping) if copy else mapping |
309 | | - |
310 | | - def __contains__(self, key: object) -> bool: |
311 | | - return key in self._mapping |
312 | | - |
313 | | - def __getitem__(self, key: HA) -> HB: |
314 | | - return self._mapping[key] |
315 | | - |
316 | | - def __iter__(self) -> tp.Iterator[HA]: |
317 | | - return iter(self._mapping) |
318 | | - |
319 | | - def __len__(self) -> int: |
320 | | - return len(self._mapping) |
321 | | - |
322 | | - def __hash__(self) -> int: |
323 | | - # use type-aware sorting to support int keys |
324 | | - def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: |
325 | | - key, _ = item |
326 | | - if isinstance(key, int): |
327 | | - return (0, key) |
328 | | - elif isinstance(key, str): |
329 | | - return (1, key) |
330 | | - else: |
331 | | - raise ValueError(f'Unsupported key type: {type(key)!r}') |
332 | | - return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn))) |
333 | | - |
334 | | - def __eq__(self, other: tp.Any) -> bool: |
335 | | - return ( |
336 | | - isinstance(other, HashableMapping) and self._mapping == other._mapping |
337 | | - ) |
338 | | - |
339 | | - def __repr__(self) -> str: |
340 | | - return repr(self._mapping) |
341 | | - |
342 | | - def update(self, other: tp.Mapping[HA, HB]) -> HashableMapping[HA, HB]: |
343 | | - """Updates the mapping with another mapping.""" |
344 | | - mapping = dict(self._mapping) |
345 | | - mapping.update(other) |
346 | | - return HashableMapping(mapping, copy=False) |
347 | | - |
348 | 304 |
|
349 | 305 | @jax.tree_util.register_static |
350 | 306 | @dataclasses.dataclass(frozen=True, repr=False) |
|
0 commit comments