Skip to content

Commit

Permalink
Add a wait_all op to collect tokens from fused channel putgets, to preserve dependency connectivity (Xilinx#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Feb 1, 2025
1 parent e52a78a commit 6869192
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
10 changes: 9 additions & 1 deletion mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4121,7 +4121,15 @@ class AIRFuseChannels
remapAllParentLoopArgs(remap, a, b);
OpBuilder builder(a);
builder.setInsertionPointAfter(a);
cloneOpAndOperands(builder, remap, b);
auto new_b = cloneOpAndOperands(builder, remap, b);
if (air::isAsyncOp(a) && air::isAsyncOp(new_b)) {
auto newWaitAll = builder.create<air::WaitAllOp>(
a->getLoc(), air::AsyncTokenType::get(a->getContext()),
SmallVector<Value>{air::getAsyncTokenFromOp(a),
air::getAsyncTokenFromOp(new_b)});
air::getAsyncTokenFromOp(a).replaceAllUsesExcept(
newWaitAll.getAsyncToken(), newWaitAll);
}
// Erase b
if (air::isAsyncOp(b)) {
IRMapping waitAllRemap;
Expand Down
15 changes: 9 additions & 6 deletions mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,9 @@ module {
// AGGRESSIVE: air.channel.put{{.*}}@channel_5{{.*}} : (memref<256x32xi8>)
// AGGRESSIVE: air.segment @segment_0
// AGGRESSIVE: scf.for %{{.*}} = %c0{{.*}} to %c2{{.*}} step %c1{{.*}}
// AGGRESSIVE-NEXT: air.channel.get{{.*}}@channel_5{{.*}} : (memref<1x2x128x16xi8, 1 : i32>)
// AGGRESSIVE-NEXT: air.channel.get{{.*}}@channel_5{{.*}} : (memref<2x1x128x128xi8, 1 : i32>)
// AGGRESSIVE-NEXT: %[[TOK0:.*]] = air.channel.get{{.*}}@channel_5{{.*}} : (memref<1x2x128x16xi8, 1 : i32>)
// AGGRESSIVE-NEXT: %[[TOK1:.*]] = air.channel.get{{.*}}@channel_5{{.*}} : (memref<2x1x128x128xi8, 1 : i32>)
// AGGRESSIVE-NEXT: air.wait_all async [%[[TOK0]], %[[TOK1]]]
// AGGRESSIVE-NEXT: scf.yield
// AGGL1: air.launch
// AGGL1: air.channel.put{{.*}}@channel_4{{.*}} : (memref<512x256xi8>)
Expand Down Expand Up @@ -1313,8 +1314,9 @@ module {
// AGGL1-NEXT: air.channel.get{{.*}}@channel_6
// AGGL1-NEXT: air.channel.get{{.*}}@channel_6
// AGGL1: scf.for
// AGGL1-NEXT: air.channel.get{{.*}}@channel_4
// AGGL1-NEXT: air.channel.get{{.*}}@channel_5
// AGGL1-NEXT: %[[TOK0:.*]] = air.channel.get{{.*}}@channel_4
// AGGL1-NEXT: %[[TOK1:.*]] = air.channel.get{{.*}}@channel_5
// AGGL1-NEXT: air.wait_all async [%[[TOK0]], %[[TOK1]]]
// AGGL1-NEXT: scf.parallel
// AGGL1: air.channel.put{{.*}}@channel_6
// AGGL1: air.channel.put{{.*}}@channel_6
Expand All @@ -1339,8 +1341,9 @@ module {
// AGGRESSIVE-NEXT: air.channel.get{{.*}}@channel_6
// AGGRESSIVE-NEXT: air.channel.get{{.*}}@channel_6
// AGGRESSIVE: scf.for
// AGGRESSIVE-NEXT: air.channel.get{{.*}}@channel_4
// AGGRESSIVE-NEXT: air.channel.get{{.*}}@channel_4
// AGGRESSIVE-NEXT: %[[TOK0:.*]] = air.channel.get{{.*}}@channel_4
// AGGRESSIVE-NEXT: %[[TOK1:.*]] = air.channel.get{{.*}}@channel_4
// AGGRESSIVE-NEXT: air.wait_all async [%[[TOK0]], %[[TOK1]]]
// AGGRESSIVE-NEXT: scf.parallel
// AGGRESSIVE: air.channel.put{{.*}}@channel_6
// AGGRESSIVE: air.channel.put{{.*}}@channel_6
Expand Down

0 comments on commit 6869192

Please sign in to comment.