We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f74158b commit 806b4d3Copy full SHA for 806b4d3
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -79,7 +79,13 @@ def lower_wrapper(*args):
79
nonlocal jax_lower_static_kwargs
80
81
jaxfn_args = []
82
- jaxfn_kwargs = jax_lower_static_kwargs.copy()
+ # TODO(junjiang): revert to jax_lower_static_kwargs.copy() once NumPy 2.0 is
83
+ # the minimum supported version.
84
+ jaxfn_kwargs = {
85
+ k: jax.numpy.array(v) if isinstance(v, float) else v
86
+ for k, v in jax_lower_static_kwargs.items()
87
+ }
88
+
89
for name, arg in zip(jax_lower_argnames, args):
90
if name is None:
91
jaxfn_args.append(arg)
0 commit comments