Skip to content

Commit 806b4d3

Browse files
junjiang-labcopybara-github
authored andcommitted
Fix JAX bridge incompatibility with NumPy < 2.0
PiperOrigin-RevId: 823564334
1 parent f74158b commit 806b4d3

File tree

1 file changed

+7
-1
lines changed
  • ai_edge_torch/odml_torch/jax_bridge

1 file changed

+7
-1
lines changed

ai_edge_torch/odml_torch/jax_bridge/_wrap.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,13 @@ def lower_wrapper(*args):
7979
nonlocal jax_lower_static_kwargs
8080

8181
jaxfn_args = []
82-
jaxfn_kwargs = jax_lower_static_kwargs.copy()
82+
# TODO(junjiang): revert to jax_lower_static_kwargs.copy() once NumPy 2.0 is
83+
# the minimum supported version.
84+
jaxfn_kwargs = {
85+
k: jax.numpy.array(v) if isinstance(v, float) else v
86+
for k, v in jax_lower_static_kwargs.items()
87+
}
88+
8389
for name, arg in zip(jax_lower_argnames, args):
8490
if name is None:
8591
jaxfn_args.append(arg)

0 commit comments

Comments
 (0)