Add experimental Shardy support. #1642

Merged
phu0ngng merged 23 commits into NVIDIA:main from jreiffers:shardyv3
Apr 14, 2025

@jreiffers
Member

jreiffers commented Apr 3, 2025

This is experimental and not intended for production use. If you try this, please file issues for problems you encounter.

Description

Shardy is a new partitioning system in JAX. It currently replaces the sharding propagation passes in GSPMD.

This PR adds Shardy support for TE's attention, normalization, quantization and softmax primitives.

To use this, simply enable Shardy:

jax.config.update("jax_use_shardy_partitioner", True)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Adds functions to provide Shardy sharding rules to primitives.

@jreiffers
Member Author

/te-ci jax L1

@phu0ngng phu0ngng self-requested a review April 7, 2025 12:45
@jreiffers
Member Author

PTAL. If this looks OK in principle, I'll do the rebase.

@jreiffers jreiffers force-pushed the shardyv3 branch 2 times, most recently from 66fb2b8 to ee15c3f Compare April 9, 2025 12:47
@jreiffers
Member Author

/te-ci jax L1

jreiffers and others added 13 commits April 10, 2025 11:44
Production use is not yet recommended.

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
test_distributed_layernorm_mlp contains some cases that require
functionality that is missing from JAX. This needs to be fixed
there before these tests can be enabled again.

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
@jreiffers
Member Author

/te-ci jax L1

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
jreiffers and others added 2 commits April 10, 2025 13:31
- Add support for variable flatten_axis values
- Hopefully improve the names in BlockScalingModeMetadataImpl::get_shardy_sharding_rules
  a bit.

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
@jreiffers
Member Author

/te-ci jax L1

Collaborator

@phu0ngng phu0ngng left a comment

The rest looks good to me!

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
@phu0ngng
Collaborator

/te-ci jax L1

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
@jreiffers
Member Author

/te-ci jax L1

Signed-off-by: Johannes Reifferscheid <jreiffers@nvidia.com>
@phu0ngng
Collaborator

/te-ci jax L1

@phu0ngng
Collaborator

Pipeline #26929185.

@phu0ngng phu0ngng merged commit 6117b20 into NVIDIA:main Apr 14, 2025
12 checks passed
2 participants