@rfbr commented Oct 13, 2025

Fixes the failing test_flash_bwd_sharded_hlo tests (local=False, i.e. ring attention) by correcting the HLO decoder so that it traverses while loop bodies and conditional branches.

The issue

Running pytest tests/test_sharding.py fails for test_flash_bwd_sharded_hlo when local=False (ring attention).
The decode_hlo function was incomplete: it followed only calls= references, so it could not see operations inside JAX's scan loops or conditional branches. It missed the body= and condition= computations used by while loops and the branch_computations={...} used by conditionals. As a result, the decoder output was just collective-permute-start collective-permute-done, with none of the custom-call operations inside the loop. The tests therefore failed, reporting no communication/computation overlap even though the overlap was working correctly.
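
To illustrate the traversal problem, here is a minimal sketch of flattening HLO text by following calls=, body=, condition= and branch_computations=. It is not the decode_hlo implementation from this PR: the names flatten_ops, parse_computations, referenced and TOY_HLO are hypothetical, the regexes assume a simplified HLO text layout, and the toy module is hand-written rather than real compiler output.

```python
import re

# Simplified patterns; real HLO text has more variety than this sketch handles.
HEADER = re.compile(r"^(ENTRY\s+)?%?([\w.\-]+)\s.*\{\s*$")
OPCODE = re.compile(r"([a-z][\w\-]*)\(")
SINGLE_REF = re.compile(r"\b(?:calls|body|condition)=%?([\w.\-]+)")
BRANCH_REF = re.compile(r"\bbranch_computations=\{([^}]*)\}")


def parse_computations(hlo_text):
    """Split HLO text into {computation name: [instruction lines]} and the entry name."""
    comps, entry, current = {}, None, None
    for raw in hlo_text.splitlines():
        line = raw.strip()
        if current is None:
            m = HEADER.match(line)
            if m:
                current = m.group(2)
                comps[current] = []
                if m.group(1):
                    entry = current
        elif line == "}":
            current = None
        elif line:
            comps[current].append(line)
    return comps, entry


def referenced(line):
    """Names of the computations that one instruction line refers to."""
    names = [m.group(1) for m in SINGLE_REF.finditer(line)]
    for m in BRANCH_REF.finditer(line):
        names += [n.strip().lstrip("%") for n in m.group(1).split(",") if n.strip()]
    return names


def flatten_ops(hlo_text):
    """Opcodes in depth-first order, descending into every referenced computation."""
    comps, entry = parse_computations(hlo_text)
    ops = []

    def visit(name, stack):
        if name in stack or name not in comps:  # guard against cycles / unknown names
            return
        for line in comps[name]:
            m = OPCODE.search(line.split("=", 1)[-1])
            if m:
                ops.append(m.group(1))
            for child in referenced(line):
                visit(child, stack + (name,))

    visit(entry, ())
    return ops


# Hand-written toy module (not real compiler output).
TOY_HLO = """
HloModule toy

%cond (p: f32[4]) -> pred[] {
  %p = f32[4] parameter(0)
  ROOT %lt = pred[] constant(true)
}

%body (p: f32[4]) -> f32[4] {
  %p = f32[4] parameter(0)
  %start = f32[4] collective-permute-start(%p)
  %attn = f32[4] custom-call(%p), custom_call_target="toy_attention"
  ROOT %done = f32[4] collective-permute-done(%start)
}

ENTRY %main (x: f32[4]) -> f32[4] {
  %x = f32[4] parameter(0)
  ROOT %loop = f32[4] while(%x), condition=%cond, body=%body
}
"""

print(flatten_ops(TOY_HLO))
```

In this toy module the custom-call lives only in the while body, so a walker that stops at calls= would never report it; following body= and condition= is what makes it visible.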

What's in this commit

  • Improved the decode_hlo function to also follow body=, condition= and branch_computations=
  • Implemented a count_overlapped_permutes function that counts overlapping and non-overlapping permutes
  • Updated the test assertions for ring attention: the forward pass should have N-1 overlapped rotations and 0 adjacent; the backward pass should have N overlapped rotations plus 1 adjacent for the final gradient return
  • Constrained the sharding tests to 2 devices to avoid a GQA incompatibility (with GQA or MQA, head dimension sharding requires that the number of devices evenly divide the group size)
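
The overlap counting can be sketched as follows. This is an assumption about what "overlapped" and "adjacent" mean, not the count_overlapped_permutes code from this PR: here a permute counts as overlapped when at least one custom-call appears between its collective-permute-start and collective-permute-done, and as adjacent otherwise. The function name, its signature and the op lists are hypothetical, and the sketch assumes at most one permute in flight at a time.

```python
def count_overlapped_permutes_sketch(ops, compute_op="custom-call"):
    """Return (overlapped, adjacent) for a flat list of opcode names.

    Assumption: each collective-permute-start is closed by the next
    collective-permute-done, with at most one permute in flight.
    """
    overlapped = adjacent = 0
    in_flight = False
    saw_compute = False
    for op in ops:
        if op == "collective-permute-start":
            in_flight, saw_compute = True, False
        elif op == "collective-permute-done" and in_flight:
            if saw_compute:
                overlapped += 1
            else:
                adjacent += 1
            in_flight = False
        elif op == compute_op and in_flight:
            saw_compute = True
    return overlapped, adjacent


START, DONE, COMPUTE = "collective-permute-start", "collective-permute-done", "custom-call"

# Hand-written op sequences shaped like the expectations above for N = 2.
fwd_ops = [START, COMPUTE, DONE, COMPUTE]
bwd_ops = [START, COMPUTE, DONE, START, COMPUTE, DONE, START, DONE]

print(count_overlapped_permutes_sketch(fwd_ops))  # (1, 0): N-1 overlapped, 0 adjacent
print(count_overlapped_permutes_sketch(bwd_ops))  # (2, 1): N overlapped, 1 adjacent
```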
