
Conversation

@sshonTT (Contributor) commented Nov 13, 2025

Ticket

N/A

Problem description

We don't have sharding rules for custom ops, so we cannot run them on multiple devices.

What's changed

This pass registers custom sharding rules in the current MLIRContext so that Shardy can correctly propagate shardings through operations that do not have built-in sharding rules in Shardy. Starting from the stablehlo.custom_call op, we can define a rule of our own for any newly introduced op.
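As a rough illustration of the mechanism (a minimal sketch, not the PR's actual code): the pass can attach an external model of Shardy's ShardingRuleOpInterface to stablehlo.custom_call on the MLIRContext, so propagation queries a per-call-target rule. The helper lookupRegisteredRule is hypothetical, the include paths are indicative, and the real interface may have more methods than the one shown.

```cpp
#include "mlir/IR/MLIRContext.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"

// Hypothetical registry lookup, keyed by the custom call's target name.
mlir::sdy::OpShardingRuleAttr
lookupRegisteredRule(llvm::StringRef target, mlir::stablehlo::CustomCallOp op);

// Sketch: an external model that attaches Shardy's ShardingRuleOpInterface
// to stablehlo.custom_call, so ops without built-in rules become visible
// to sharding propagation.
struct CustomCallShardingRuleModel
    : public mlir::sdy::ShardingRuleOpInterface::ExternalModel<
          CustomCallShardingRuleModel, mlir::stablehlo::CustomCallOp> {
  mlir::sdy::OpShardingRuleAttr getShardingRule(mlir::Operation *op) const {
    auto customCall = mlir::cast<mlir::stablehlo::CustomCallOp>(op);
    return lookupRegisteredRule(customCall.getCallTargetName(), customCall);
  }
};

// Registered once on the current MLIRContext (the stablehlo dialect must
// already be loaded), e.g. at the start of the pass.
void registerCustomCallShardingRules(mlir::MLIRContext &ctx) {
  mlir::stablehlo::CustomCallOp::attachInterface<CustomCallShardingRuleModel>(ctx);
}
```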

Checklist

  • New/Existing tests provide coverage for changes

getCustomCallShardingRule(mlir::stablehlo::CustomCallOp op) const {
  llvm::StringRef target = op.getCallTargetName();

  auto it = customCallShardingRules.find(target);
Contributor

You can use llvm::DenseMap::lookup here to get the rule directly.

Contributor Author

Changed it to `lookup`.
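For reference, a minimal sketch of the updated lookup, assuming customCallShardingRules is an llvm::DenseMap keyed by the call-target name; the value type ShardingRuleFn is a hypothetical stand-in for whatever the pass actually stores:

```cpp
#include <functional>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "stablehlo/dialect/StablehloOps.h"

// Hypothetical value type; the real rule type lives in the pass.
using ShardingRuleFn = std::function<void(mlir::stablehlo::CustomCallOp)>;

struct CustomShardingRuleRegistry {
  llvm::DenseMap<llvm::StringRef, ShardingRuleFn> customCallShardingRules;

  ShardingRuleFn
  getCustomCallShardingRule(mlir::stablehlo::CustomCallOp op) const {
    // DenseMap::lookup returns a default-constructed value (an empty
    // std::function here) when the key is absent, so it avoids the
    // find()/end() iterator dance from the original version.
    return customCallShardingRules.lookup(op.getCallTargetName());
  }
};
```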

@@ -1,8 +1,4 @@
// REQUIRES: stablehlo
// This file incorporates work covered by the following copyright and permission notice:
Contributor Author

I know this isn’t directly related to this PR, but I included it since it’s a simple change.

Contributor

IIRC this copyright notice was added because this test (or part of it) was copied from some other project. I think @tapspatel added it. If that is still the case, I think we should keep this copyright.

Contributor Author

This test file is something I added in an earlier commit, and it looks like this copyright line got carried over while copy-pasting. @tapspatel, to be sure, do you think we should remove it? I don’t see the same notice in other test files.

@codecov-commenter commented Nov 14, 2025

Codecov Report

❌ Patch coverage is 91.11111% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.62%. Comparing base (594aa8e) to head (907ecca).
⚠️ Report is 5 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines                               Patch %   Lines
...tableHLO/Transforms/RegisterCustomShardingRule.cpp  90.90%    4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5851      +/-   ##
==========================================
+ Coverage   69.51%   69.62%   +0.11%     
==========================================
  Files         327      328       +1     
  Lines       48692    48888     +196     
==========================================
+ Hits        33848    34039     +191     
- Misses      14844    14849       +5     


@AleksKnezevic (Contributor) commented:

This is awesome @sshonTT. Do we want to use a custom or composite op for SDPA? You can probably get this in for paged attention too? The sharding rules will be the same.

@sshonTT (Contributor Author) commented Nov 14, 2025

@AleksKnezevic I think using both should be fine, since they seem to map to a single TTNN op anyway. Of course, we’ll need to check the performance to be sure.
And as you mentioned, if the same sharding rules apply, we should be able to map them in the same way for paged attention as well.

@sshonTT (Contributor Author) commented Nov 14, 2025

@AleksKnezevic I thought about it again, and I think the custom call approach might actually give us more benefits unless head sharding is confirmed. With a composite op, if sharding propagation determines that a CCL operation needs to be inserted inside the composite function, the whole thing ends up running as flattened ops.
On the other hand, with a custom op, the CCL operations would be inserted before the custom call for replication, but the core computation can still run as a single TTNN op.

@AleksKnezevic (Contributor) commented:

@sshonTT, that makes sense. If we find the way we handle composite ops doesn't work for some composites, is it possible to handle them through this flow (i.e. if the composite op is in some table then apply a custom rule to it, otherwise flatten and propagate)?

This is not a change I'm requesting now, just something to think about.

@sshonTT (Contributor Author) commented Nov 14, 2025

@AleksKnezevic Yes, I think that would be possible. Adding new rules isn’t difficult with the current setup, and we would just need to register exceptions for the cases where flattening shouldn’t apply.
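A tiny sketch of what such an exception table could look like, purely illustrative; the set contents and the names below are hypothetical, not from the PR:

```cpp
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

// Hypothetical exception table: composites listed here keep a registered
// custom sharding rule and stay intact as single ops; everything else is
// flattened before propagation, as today.
static bool shouldKeepCompositeIntact(llvm::StringRef compositeName) {
  static const llvm::StringSet<> keepIntact = {
      "tenstorrent.scaled_dot_product_attention", // hypothetical name
  };
  return keepIntact.contains(compositeName);
}
```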


/*isBlocked=*/usedByRngBitGenerator)
.build();
}
+ // Check if the custom call implements the ShardingRuleOpInterface.
Contributor Author

This seems upstream-able, so I submitted a patch to Shardy (openxla/shardy#885).
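To make the diff context above concrete, the added check might look roughly like this inside the rule-factory path; this is a sketch under the assumption that the surrounding function produces an OpShardingRuleAttr, not the exact upstream code:

```cpp
// If the op (e.g. a stablehlo.custom_call with an attached external
// model) implements Shardy's ShardingRuleOpInterface, prefer its own
// rule over any built-in fallback.
if (auto shardingRuleOp = mlir::dyn_cast<mlir::sdy::ShardingRuleOpInterface>(op)) {
  return shardingRuleOp.getShardingRule();
}
```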

@sshonTT (Contributor Author) commented Nov 14, 2025

@vwellsTT, @nsmithtt Could you please take a look as code owners?

@nsmithtt (Contributor) left a comment:

Approving cmake changes!

@sshonTT sshonTT merged commit a383e50 into main Nov 17, 2025
32 checks passed
@sshonTT sshonTT deleted the sshon/pr-own-sharding-rule branch November 17, 2025 15:18
abogdanovicTT pushed a commit that referenced this pull request Nov 27, 2025
### Ticket
N/A

### Problem description
We don't have sharding rules for custom ops, so we cannot run them on
multiple devices.

### What's changed
This pass registers custom sharding rules in the current MLIRContext so
that Shardy can correctly propagate shardings through operations that do
not have built-in sharding rules in Shardy. Starting from the
stablehlo.custom_call op, we can define a rule of our own for any newly
introduced op.

### Checklist
- [ ] New/Existing tests provide coverage for changes
