We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0183a95 commit 8f8271eCopy full SHA for 8f8271e
flax/nnx/rnglib.py
@@ -124,6 +124,12 @@ def __call__(self) -> jax.Array:
124
self.count[...] += 1
125
return key
126
127
+ def key(self) -> jax.Array:
128
+ return self()
129
+
130
+ def split(self, k: int):
131
+ return self.fork(split=k)
132
133
def fork(self, *, split: int | tuple[int, ...] | None = None):
134
key = self()
135
if split is not None:
0 commit comments