Skip to content

Commit 86f7bc8

Browse files
Jake VanderPlasKfacJaxDev
authored andcommitted
Avoid usage of deprecated jax.core APIs.
These APIs are deprecated as of JAX v0.10.0, replaced by equivalents in `jax.extend.core` (see https://docs.jax.dev/en/latest/jax.extend.html for details). PiperOrigin-RevId: 900270194
1 parent d687c27 commit 86f7bc8

3 files changed

Lines changed: 31 additions & 11 deletions

File tree

kfac_jax/_src/layers_and_loss_tags.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919
import jax
2020
import jax.extend as jex
2121

22+
try:
23+
# JAX v0.10.0 and newer
24+
Effects: type[Any] = jex.core.Effects
25+
no_effects: Effects = jex.core.no_effects
26+
except AttributeError:
27+
# JAX v0.9.2 and older
28+
Effects = jax.core.Effects
29+
no_effects = jax.core.no_effects
30+
2231

2332
# Types for annotation
2433
T = TypeVar("T")
@@ -135,9 +144,9 @@ def abstract_eval(
135144
self,
136145
*args: Array,
137146
**params: Any,
138-
) -> tuple[Arrays, jax.core.Effects]:
147+
) -> tuple[Arrays, Effects]:
139148

140-
return get_loss_outputs(args, params), jax.core.no_effects
149+
return get_loss_outputs(args, params), no_effects
141150

142151
def _mlir_lowering(
143152
self,
@@ -336,10 +345,10 @@ def abstract_eval(
336345
self,
337346
*args: Array,
338347
**params: Any,
339-
) -> tuple[Array, jax.core.Effects]:
348+
) -> tuple[Array, Effects]:
340349
# For now we support only single output
341350
[output] = self.layer_data(args, params).outputs
342-
return output, jax.core.no_effects
351+
return output, no_effects
343352

344353
def _batching(
345354
self,

kfac_jax/_src/tag_graph_matcher.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,15 @@
3434
jax.__version_info__ if hasattr(jax, "__version_info__")
3535
else tuple(map(int, jax.__version__.split("."))))
3636

37-
if jax_version >= (0, 5, 1):
37+
if jax_version >= (0, 10, 0):
38+
DebugInfo = jex.core.DebugInfo
39+
DropVar = jex.core.DropVar
40+
elif jax_version >= (0, 5, 1):
3841
DebugInfo = jax.core.DebugInfo
42+
DropVar = jax.core.DropVar
3943
else:
4044
DebugInfo = jax.core.JaxprDebugInfo # pytype: disable=module-attr
45+
DropVar = jax.core.DropVar
4146

4247

4348
HIGHER_ORDER_NAMES = ("cond", "while", "scan", "pjit", "xla_call", "xla_pmap")
@@ -376,7 +381,7 @@ def make_jax_graph(
376381
new_out_vars = []
377382
for v in eqn.outvars:
378383

379-
if isinstance(v, jax.core.DropVar):
384+
if isinstance(v, DropVar):
380385
new_out_vars.append(make_var_func(v.aval))
381386
else:
382387
new_out_vars.append(v)
@@ -899,7 +904,7 @@ def read_env(
899904
if isinstance(v, jex.core.Literal):
900905
# Literals are values baked into the Jaxpr
901906
result.append(v.val)
902-
elif isinstance(v, jax.core.DropVar):
907+
elif isinstance(v, DropVar):
903908
result.append(None)
904909
else:
905910
result.append(env[v])

tests/test_graph_matcher.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,21 @@
2727
Array = kfac_jax.utils.Array
2828
Shape = kfac_jax.utils.Shape
2929

30+
try:
31+
# JAX v0.10.0 and newer
32+
DropVar = jex.core.DropVar
33+
except AttributeError:
34+
# JAX v0.9.2 and older
35+
DropVar = jax.core.DropVar
36+
3037

3138
class TestGraphMatcher(parameterized.TestCase):
3239
"""Test class for the functions in `tag_graph_matcher.py`."""
3340

3441
def check_equation_match(self, eqn1, vars_to_vars, vars_to_eqn):
3542
"""Checks that equation is matched in the other graph."""
3643

37-
eqn1_out_vars = [v for v in eqn1.outvars
38-
if not isinstance(v, jax.core.DropVar)]
44+
eqn1_out_vars = [v for v in eqn1.outvars if not isinstance(v, DropVar)]
3945
eqn2_out_vars = [vars_to_vars[v] for v in eqn1_out_vars]
4046
eqns = [vars_to_eqn[v] for v in eqn2_out_vars]
4147
self.assertTrue(all(e == eqns[0] for e in eqns[1:]))
@@ -124,8 +130,8 @@ def check_jaxpr_equal(self, jaxpr_1, jaxpr_2, map_output_vars: bool):
124130
for eqn1, eqn2 in zip(l1_eqns, l2_eqns):
125131
self.assertEqual(len(eqn1.outvars), len(eqn2.outvars))
126132
for v1, v2 in zip(eqn1.outvars, eqn2.outvars):
127-
if isinstance(v1, jax.core.DropVar):
128-
self.assertIsInstance(v2, jax.core.DropVar)
133+
if isinstance(v1, DropVar):
134+
self.assertIsInstance(v2, DropVar)
129135
elif isinstance(v1, jex.core.Literal):
130136
self.assertIsInstance(v2, jex.core.Literal)
131137
self.assertEqual(v1.aval, v2.aval)

0 commit comments

Comments
 (0)