diff --git a/shardy/integrations/python/jax/mpmd/ops.py b/shardy/integrations/python/jax/mpmd/ops.py index 9e0c46a3..c8c2ed70 100644 --- a/shardy/integrations/python/jax/mpmd/ops.py +++ b/shardy/integrations/python/jax/mpmd/ops.py @@ -294,7 +294,9 @@ def custom_call_transpose(params, *rest, primitive=primitive): pe.partial_eval_jaxpr_custom_rules[primitives.call_p] ) state_discharge.register_discharge_rule(primitive)( - state_discharge._call_discharge_rule + functools.partial( + state_discharge._call_primitive_discharge_rule, primitive + ) ) return primitive