Fix HLO overlap tests adding proper decoding of while loops and conditional branches #10
Fixes failing `test_flash_bwd_sharded_hlo` tests (when `local=False`, i.e. ring attention) by correcting the HLO decoder to properly traverse while loop bodies and conditional branches.

The issue

Running `pytest tests/test_sharding.py` fails for `test_flash_bwd_sharded_hlo` when `local=False` (ring attention).

The `decode_hlo` function was incomplete and could not see operations inside JAX's `scan` loops or conditional branches: it only followed `calls=`, missing the `body=` and `condition=` attributes used by while loops and the `branch_computations={...}` attribute used by conditionals. As a result, the decoder output was just `collective-permute-start collective-permute-done`, missing all the `custom-call` operations inside the loop. The tests therefore failed, claiming no communication/computation overlap when overlap was actually working correctly.

What's in this commit

- The `decode_hlo` function now also follows `body=`, `condition=` and `branch_computations=`.
- A `count_overlapped_permutes` function computes the number of overlapping and non-overlapping permutes.
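The traversal described above can be sketched roughly as follows. This is a minimal illustration and not the code from this commit: the names `parse_computations`, `decode_ops` and `count_overlap_sketch`, the regular expressions, and the hand-written, simplified HLO-like text are all assumptions. So is the overlap rule used here, which treats a permute as overlapped when a `custom-call` appears between its start and its done in the flattened op sequence.

```python
import re

# Simplified, hand-written HLO-like text (not real compiler output).
HLO_TEXT = """
HloModule toy

%body (p: f32[4]) -> f32[4] {
  %p = f32[4] parameter(0)
  %cps = f32[4] collective-permute-start(%p), source_target_pairs={{0,1}}
  %cc = f32[4] custom-call(%p), custom_call_target="flash"
  ROOT %cpd = f32[4] collective-permute-done(%cps)
}

%cond (p: f32[4]) -> pred[] {
  ROOT %c = pred[] constant(true)
}

ENTRY %main (x: f32[4]) -> f32[4] {
  %x = f32[4] parameter(0)
  ROOT %w = f32[4] while(%x), condition=%cond, body=%body
}
"""

# Assumed patterns; they only handle shapes without spaces (no tuple shapes).
HEADER = re.compile(r"^(ENTRY\s+)?%?([\w.\-]+)\s*\(.*\{\s*$")
OPCODE = re.compile(r"=\s*\S+\s+([\w\-]+)\(")
SINGLE_REF = re.compile(r"\b(?:calls|body|condition)=%?([\w.\-]+)")
BRANCH_REFS = re.compile(r"branch_computations=\{([^}]*)\}")


def parse_computations(hlo_text):
    """Split the text into {computation name: instruction lines} plus the entry name."""
    computations, entry, current = {}, None, None
    for raw in hlo_text.splitlines():
        line = raw.strip()
        header = HEADER.match(line)
        if header:
            current = header.group(2)
            computations[current] = []
            if header.group(1):
                entry = current
        elif line == "}":
            current = None
        elif current is not None and line:
            computations[current].append(line)
    return computations, entry


def referenced(line):
    """Names of computations referenced via calls=, body=, condition=, branch_computations=."""
    names = SINGLE_REF.findall(line)
    for group in BRANCH_REFS.findall(line):
        names += [n.strip().lstrip("%") for n in group.split(",") if n.strip()]
    return names


def decode_ops(computations, name, seen=None):
    """Flatten opcodes depth-first; each computation is expanded at most once."""
    seen = set() if seen is None else seen
    if name in seen or name not in computations:
        return []
    seen.add(name)
    ops = []
    for line in computations[name]:
        match = OPCODE.search(line)
        if match:
            ops.append(match.group(1))
        for callee in referenced(line):
            ops.extend(decode_ops(computations, callee, seen))
    return ops


def count_overlap_sketch(ops, compute_op="custom-call"):
    """Return (overlapped, non_overlapped) permutes under the assumed rule above."""
    overlapped = non_overlapped = 0
    in_flight = []  # one flag per pending permute: was compute seen since its start?
    for op in ops:
        if op == "collective-permute-start":
            in_flight.append(False)
        elif op == compute_op:
            in_flight = [True] * len(in_flight)
        elif op == "collective-permute-done" and in_flight:
            if in_flight.pop(0):
                overlapped += 1
            else:
                non_overlapped += 1
    return overlapped, non_overlapped


computations, entry = parse_computations(HLO_TEXT)
ops = decode_ops(computations, entry)
print(" ".join(ops))
print(count_overlap_sketch(ops))
```

In this toy input the `custom-call` lives only in the while body, so a decoder that ignored `body=` would never reach it. That mirrors the failure mode described in the issue.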