[LLVMGPUVectorDistribute] Support vector.mask + vector.multi_reduce #19880
base: main
Conversation
@Groverkss is it fair to assume region nesting will be honored when distributing?
Alright, I introduced a hook for this (see the commit below).
This commit enables vector layout propagation into and out of vector.mask and its body. Moreover, it enables the distribution of vector.multi_reduce that is wrapped in a vector.mask. This is done as follows:
* The distributed mask is applied to the thread-local reduction.
* The distributed operand is selected between the reduction identity and the provided operand using the distributed mask.
Signed-off-by: Manupa Karunaratne <[email protected]>
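The two bullet points above can be sketched in plain Python. This is illustrative only: the actual pass manipulates distributed MLIR vectors, and the function name, shapes, and list-of-slices encoding below are all hypothetical.

```python
# Hypothetical sketch of the masked multi-reduction distribution strategy
# described above, using plain Python lists in place of distributed
# MLIR vectors. Each (src, mask) pair models one thread's slice.

ADD_IDENTITY = 0.0  # combining identity for the <add> kind

def distribute_masked_reduce(src_slices, mask_slices, init):
    acc = init
    for src, mask in zip(src_slices, mask_slices):
        # Step 1: select between the operand and the reduction identity
        # using the distributed mask (models arith.select).
        selected = [s if m else ADD_IDENTITY for s, m in zip(src, mask)]
        # Step 2: thread-local reduction; masked-off lanes contribute
        # the identity, so the result matches a masked reduction.
        acc += sum(selected)
    return acc
```

A cross-thread combine (e.g. a subgroup reduction) would then merge the per-thread accumulators; that part is unchanged by the masking.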
Adds a hook to provide vector.mask { op } rewrites. This removes the rewrite-ordering constraint that would otherwise require the body op to be distributed before the mask op. Using this hook, developers can write masked-op distribution patterns in which the pre-distribution mask op is removed as part of the rewrite. Signed-off-by: Manupa Karunaratne <[email protected]>
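A conceptual sketch of such a hook follows. This is not IREE's actual C++ API; the registry, names, and dict-based op encoding are invented purely to illustrate how a mask-plus-body rewrite removes the ordering constraint.

```python
# Conceptual sketch (not IREE code) of a hook that lets one pattern
# rewrite `vector.mask { op }` as a single unit, instead of requiring
# the body op to be distributed before the enclosing mask op.

MASKED_REWRITES = {}

def register_masked_rewrite(op_name, fn):
    """Register a rewrite for a `vector.mask { op_name }` pair."""
    MASKED_REWRITES[op_name] = fn

def rewrite_mask(mask_op):
    body = mask_op["body"]
    hook = MASKED_REWRITES.get(body["name"])
    if hook is not None:
        # The hook rewrites mask and body together; the
        # pre-distribution mask op is erased as part of the rewrite.
        return hook(mask_op, body)
    # No hook: signal that the default (body-first) path must be used.
    return None
```

The design point is that the hook sees both ops at once, so a masked multi-reduction pattern can consume the mask without ever producing a half-distributed intermediate state.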
PTAL @qedawkins if you have some time...
One main question about why we need both the local mask and the select, otherwise LGTM
std::function<void(DistributionLayout *, mlir::ChangeResult)> update) {
  mask.getBody()->walk(
      [&](Operation *traversed) { visitOperation(traversed); });
  // Propogate from body to results
Suggested change:
- // Propogate from body to results
+ // Propagate from body to results.
done
}
mask = getDistributed(rewriter, maskOp.getMask(), maskLayout);
Value passThruSrc = getCombiningIdentityValue(
    loc, rewriter, multiReduceOp.getKind(), disSrc.getType());
vector.mask can carry its own pass_thru, which I'm guessing goes here.
skipped as discussed down below
    loc, disSrc, localInit, distributedReductionMask,
    multiReduceOp.getKind());
if (mask) {
  localReduction =
      vector::maskOperation(rewriter, localReduction.getDefiningOp(), mask)
Why do we need the arith.select and the vector.mask?
removed post-distribution masking for now as discussed.
// CHECK: %[[MASK_ITL_PCK:.+]] = vector.transpose %[[MASK_PCK]], [0, 3, 1, 4, 2, 5] : vector<2x2x2x1x1x8xi1> to vector<2x1x2x1x2x8xi1>
// CHECK: %[[SELECT:.+]] = arith.select %[[MASK_ITL_PCK]], {{.*}}, %[[RED_IDENTITY]] : vector<2x1x2x1x2x8xi1>, vector<2x1x2x1x2x8xf16>
// CHECK: vector.mask %[[MASK_ITL_PCK]] { vector.multi_reduction <add>, %[[SELECT]], {{.*}} [0, 2, 4] : vector<2x1x2x1x2x8xf16> to vector<1x1x8xf16> } : vector<2x1x2x1x2x8xi1> -> vector<1x1x8xf16>
Can you add a test with pass_thru on the vector.mask?
skipped as discussed down below
So the masking is for thread-local reductions. Re-thinking, maybe the init might cover that already -- I can give removing the select a go.
Wait... no, if we want to support pass_thru then the select is needed.
Yeah, I think we need to keep the select (although it's also fine to just not support pass_thru right now) and we can try dropping the mask.
I can add pass_thru, but why drop the mask?
Isn't the mask redundant if we have the select?
E.g., if the thread-local reduction dimension is long (longer than a machine vector), wouldn't the mask be used to cut down the instructions issued?
It didn't look like the existing lowerings were doing that, but I might not have looked closely enough. If it does work out like that, keeping the mask makes sense. I've mostly been asking because the mask was surprising to me; I can approve and leave it as a future exercise to determine whether it's useful.
I spent time reading the upstream code and traces as well now. I'll leave a comment here and remove the post-distribution mask, just so we don't trip on anything.
Ah ok, well ignore the pass_thru then. Sounds good to me! We can always add it back later if it's better.
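The select-vs-mask question above comes down to a small semantic fact: for a combining kind with an identity (e.g. add with 0), selecting the identity into the inactive lanes and then reducing all lanes gives the same result as reducing only the active lanes. A minimal Python sketch of that equivalence (illustrative only, not IREE code):

```python
# Why the arith.select alone already yields masked-reduction semantics
# for combining kinds that have an identity element.

def masked_reduce(vals, mask):
    # Reference semantics: reduce only the active lanes.
    return sum(v for v, m in zip(vals, mask) if m)

def select_then_reduce(vals, mask, identity=0.0):
    # Select the identity into inactive lanes, then reduce all lanes.
    # Inactive lanes contribute the identity, so the result is identical.
    return sum(v if m else identity for v, m in zip(vals, mask))
```

Whether the *additional* vector.mask around the reduction is worth keeping is then purely a codegen question (can the backend use it to issue fewer instructions?), which is what the thread above leaves as a future exercise.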
Depends on #19830 (hence putting this in draft until that's merged).