Add experimental Shardy support.#1642
Merged
phu0ngng merged 23 commits intoNVIDIA:mainfrom Apr 14, 2025
Merged
Conversation
Member
Author
|
/te-ci jax L1 |
phu0ngng
reviewed
Apr 7, 2025
phu0ngng
reviewed
Apr 7, 2025
phu0ngng
reviewed
Apr 7, 2025
Member
Author
|
PTAL. If this looks OK in principle, I'll do the rebase. |
66fb2b8 to
ee15c3f
Compare
Member
Author
|
/te-ci jax L1 |
Production use is not yet recommended. Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
test_distributed_layernorm_mlp contains some cases that require functionality that is missing from JAX. This needs to be fixed there before these tests can be enabled again. Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Member
Author
|
/te-ci jax L1 |
phu0ngng
reviewed
Apr 10, 2025
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
- Add support for variable flatten_axis values - Hopefully improve the names in BlockScalingModeMetadataImpl::get_shardy_sharding_rules a bit. Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
for more information, see https://pre-commit.ci
Member
Author
|
/te-ci jax L1 |
phu0ngng
requested changes
Apr 10, 2025
Collaborator
phu0ngng
left a comment
There was a problem hiding this comment.
The rest looks good to me!
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Collaborator
|
/te-ci jax L1 |
phu0ngng
reviewed
Apr 10, 2025
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Member
Author
|
/te-ci jax L1 |
phu0ngng
reviewed
Apr 11, 2025
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
Collaborator
|
/te-ci jax L1 |
phu0ngng
approved these changes
Apr 14, 2025
Collaborator
|
Pipeline #26929185. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This is experimental and not intended for production use. If you try this, please file issues for problems you encounter.
Description
Shardy is a new partitioning system in JAX. It currently replaces the sharding propagation passes in GSPMD.
This PR adds Shardy support for TE's attention, normalization, quantization and softmax primitives.
To use this, simply enable Shardy:
Type of change
Changes
Please list the changes introduced in this PR: