PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review
Error Handling Robustness

The change from `continue` to `NVF_THROW` in the sharding validation loop (line 393) could be overly strict. It throws an error when a parallel type is not sharded, but the original `continue` behavior suggests this might be a valid case. Verify that this doesn't break legitimate use cases where some parallel types aren't sharded.
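To make the concern concrete, here is a rough sketch of the control-flow change being flagged; `validateSharding` and `isShardedOn` are illustrative stand-ins rather than the actual code at line 393, while `NVF_THROW` and `kParallelTypeDIDs` are the real nvFuser names.

```cpp
#include <exceptions.h>         // NVF_THROW
#include <ir/interface_nodes.h> // TensorView
#include <type.h>               // ParallelType, kParallelTypeDIDs

// Hypothetical predicate standing in for whatever check the real loop uses.
bool isShardedOn(const nvfuser::TensorView* tv, nvfuser::ParallelType pt);

// Illustrative stand-in for the validation loop under review.
void validateSharding(const nvfuser::TensorView* tv) {
  using namespace nvfuser;
  for (ParallelType pt : kParallelTypeDIDs) {
    if (!isShardedOn(tv, pt)) {
      // Old behavior: silently skip parallel types that are not sharded.
      //   continue;
      // New behavior: treat an unsharded parallel type as a hard error.
      NVF_THROW("Expected ", tv->toString(), " to be sharded on ", pt);
    }
    // ... per-parallel-type validation of the sharded axis ...
  }
}
```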
for multi-GPU debugging. Multi-GPU scheduling happens before segmentation and the shardings are encoded as loop transforms.
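As a concrete illustration of that encoding, below is a minimal sketch of how a sharding can be expressed as a loop transform, assuming the usual nvFuser multidevice C++ API (`DeviceMesh`, `ParallelType::DIDx`); the function name, axis choice, and header paths are illustrative.

```cpp
#include <fusion.h>
#include <ir/interface_nodes.h>
#include <multidevice/device_mesh.h>
#include <type.h>

// Shard the outer axis of `tv` across `num_devices` ranks. The sharding is
// expressed as a loop transform plus a parallel type, so it is already
// visible to the scheduling passes that run before segmentation.
void shardOuterAxis(nvfuser::TensorView* tv, int64_t num_devices) {
  using namespace nvfuser;
  // Every sharded TensorView carries a device mesh.
  tv->setDeviceMesh(DeviceMesh::createForNumDevices(num_devices));
  // Outer-split the axis by the mesh size ...
  tv->split(/*axis=*/0, /*factor=*/num_devices, /*inner_split=*/false);
  // ... and bind the resulting outer loop to the device-parallel type.
  tv->axis(0)->parallelize(ParallelType::DIDx);
}
```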
I have been pushing the implementation forward using several hacks to uncover edge cases. Below are the key issues identified and their current status:
1. Sharding Propagation Rework
2. Multi-Dimensional Sharding & getCommunicationInfo

Extended `getCommunicationInfo` to support multi-dimensional sharding. It reuses `haveDifferentShardings` to identify inconsistencies between input and output `TensorView` objects. The commit needs cleanup and further test verification before it can be merged. `haveDifferentShardings` is currently bottlenecked by the expensive `ExpressionSimplifier`; we need to transition it to be IdModel-based in a future iteration.

3. Misaligned Memory Access in Transpose Kernels
Handled by extending `ReorderShardedAxisPass` to ensure the scattered axis of the `ReduceScatter` is allocated outermost (see the sketch after this list).

4. Performance Bottleneck: AllGather memory
The `AllGather` preceding the Einsum is functional but consumes too much memory for AlphaFold3 workloads due to long sequence lengths.

cc @DejunL
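Regarding item 3, here is a hedged sketch of the allocation order the pass is meant to guarantee; `allocateScatteredAxisOutermost`, `rs_out`, and `scattered_axis` are illustrative names, and this is not the actual `ReorderShardedAxisPass` implementation.

```cpp
#include <algorithm>
#include <vector>

#include <ir/interface_nodes.h>

// Move the scattered IterDomain to the front of the allocation domain so the
// ReduceScatter's output is allocated with that axis outermost. Communication
// then touches contiguous, well-aligned chunks instead of the strided
// accesses behind the misaligned loads in the transpose kernels.
void allocateScatteredAxisOutermost(
    nvfuser::TensorView* rs_out,
    int64_t scattered_axis) {
  using namespace nvfuser;
  std::vector<IterDomain*> alloc = rs_out->getLogicalDomain();
  // Rotate the scattered IterDomain to position 0, preserving the relative
  // order of the remaining axes.
  std::rotate(
      alloc.begin(),
      alloc.begin() + scattered_axis,
      alloc.begin() + scattered_axis + 1);
  rs_out->setAllocationDomain(alloc, /*new_contiguity=*/true);
}
```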