-
Notifications
You must be signed in to change notification settings - Fork 560
feat: abstraction of xla::OpSharding proto using wrapper class #9467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kvshbg-aws
wants to merge
13
commits into
pytorch:master
Choose a base branch
from
kvshbg-aws:kvshbg-aws/local-spmd-abstraction
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
feat: abstraction of xla::OpSharding proto using wrapper class #9467
kvshbg-aws
wants to merge
13
commits into
pytorch:master
from
kvshbg-aws:kvshbg-aws/local-spmd-abstraction
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
d0502ab
to
7fc15ea
Compare
7c4a3cd
to
1d55ae9
Compare
1ddbb1b
to
2756c1a
Compare
2756c1a
to
4556faa
Compare
pgmoka
reviewed
Aug 19, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left comments. Overall, the PR looks very good.
…_assignment() is empty
4752558
to
a8458ac
Compare
LGTM pending tests |
qihqi
approved these changes
Sep 4, 2025
pgmoka
approved these changes
Sep 9, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR includes the changes related to abstracting
xla::OpSharidng
proto object into atorch_xla::OpSharding
wrapper class.This new class object will not have the requirements of xla::OpSharding (however, it will be an extension xla::OpSharding proto defined over here).
We have defined the wrapper class in torch/xla which will construct an xla::OpSharding object with additional fields such as global_device_ids/global_tile_assignment and will have forwarded/proxy functions to xla::OpSharding . These forwarded functions will help user still make use of the same
xla::OpSharding
APIs as they normally would. We can also define torch_xla specific functions in this wrapper class to further use the extra fields that were stored during the initialization of the OpSharding object. This approach also allows the flexibility of converting the torch_xla::OpSharding object back to xla::OpSharding while lowering into HLO, thus, giving user the flexibility to use the abstracted class (and other additional fields stored) anywhere in the code base as needed, this is particularly useful since the XLA's HLOs are 0th indexed, hence we need to use the normalized_device_ids (starting from index 0) when lowering the program into the HLO, whereas we can still use the denormalized/global_device_ids in other places such as inside pjrt client to set the device_assignment using the user specified device_ids.Component diagram for reference -

Ref issue - #9390