Add pass to register custom sharding rules #5851
Conversation
This pass registers custom sharding rules in the current MLIRContext so that Shardy can correctly propagate shardings through operations that have no built-in sharding rules in Shardy. Starting from the stablehlo.custom_call op, we can define a rule of our own for any newly introduced op.
getCustomCallShardingRule(mlir::stablehlo::CustomCallOp op) const {
  llvm::StringRef target = op.getCallTargetName();

  auto it = customCallShardingRules.find(target);
You can use llvm::DenseMap::lookup here to get the rule directly.
Changed it to lookup
@@ -1,8 +1,4 @@
// REQUIRES: stablehlo
// This file incorporates work covered by the following copyright and permission notice:
I know this isn’t directly related to this PR, but I included it since it’s a simple change.
IIRC this copyright notice was added because this test (or part of it) was copied from some other project. I think @tapspatel added it. If that is still the case, we should keep the copyright.
This test file is something I added in an earlier commit, and it looks like this copyright line got carried over while copy-pasting. @tapspatel, to be sure, do you think we should remove it? I don’t see the same notice in other test files.
Codecov Report
❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #5851      +/-   ##
==========================================
+ Coverage   69.51%   69.62%   +0.11%
==========================================
  Files         327      328       +1
  Lines       48692    48888     +196
==========================================
+ Hits        33848    34039     +191
- Misses      14844    14849       +5
This is awesome @sshonTT. Do we want to use a custom or composite op for SDPA? You can probably get this in for paged attention? The sharding rules will be the same.
change method
@AleksKnezevic I think using both should be fine, since they seem to map to a single TTNN op anyway. Of course, we'll need to check the performance to be sure.
@AleksKnezevic I thought about it again, and I think the custom call approach might actually give us more benefits unless head sharding is confirmed. With a composite op, if sharding propagation determines that a CCL operation needs to be inserted inside the composite function, the whole thing ends up running as flattened ops.
@sshonTT, that makes sense. If we find the way we handle composite ops doesn't work for some composites, is it possible to handle them through this flow (i.e. if the composite op is in some table then apply a custom rule to it, otherwise flatten and propagate)? This is not a change I'm requesting now, just something to think about.
@AleksKnezevic Yes, I think that would be possible. Adding new rules isn't difficult with the current setup, and we would just need to register exceptions for the cases where flattening shouldn't apply.
  /*isBlocked=*/usedByRngBitGenerator)
      .build();
}
+ // Check if the custom call implements the ShardingRuleOpInterface.
This seems upstreamable, so I uploaded a patch to Shardy (openxla/shardy#885)
nsmithtt left a comment:
Approving cmake changes!
Ticket
N/A

Problem description
We don't have sharding rules for custom ops, so we cannot run them on multiple devices.

What's changed
This pass registers custom sharding rules in the current MLIRContext so that Shardy can correctly propagate shardings through operations that have no built-in sharding rules in Shardy. Starting from the stablehlo.custom_call op, we can define a rule of our own for any newly introduced op.

Checklist
- [ ] New/Existing tests provide coverage for changes