Skip to content

Commit 19d8652

Browse files
sharadmvcopybara-github
authored andcommitted
Plumb prim params for call discharge rule (to handle named_computation_p
from shardy). PiperOrigin-RevId: 853813788
1 parent 6bc665a commit 19d8652

File tree

1 file changed

+3
-1
lines changed
  • shardy/integrations/python/jax/mpmd

1 file changed

+3
-1
lines changed

shardy/integrations/python/jax/mpmd/ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,9 @@ def custom_call_transpose(params, *rest, primitive=primitive):
294294
pe.partial_eval_jaxpr_custom_rules[primitives.call_p]
295295
)
296296
state_discharge.register_discharge_rule(primitive)(
297-
state_discharge._call_discharge_rule
297+
functools.partial(
298+
state_discharge._call_primitive_discharge_rule, primitive
299+
)
298300
)
299301
return primitive
300302

0 commit comments

Comments
 (0)