[DispatchCreation] Modify the generated fused op to not use concats. #19980
It might be better to just review the new changes by themselves and ignore the diff. The pass is essentially rewritten.
9ebc531
to
bfca1d5
Compare
There's a problem with how ops are grouped.
util.func public @test_partial_horizontal_fuse(%arg0: tensor<640x640xf32>, %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, %arg3: tensor<640x640xf32>) -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<640x640xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
}
This isn't directly related to the changes you made; I think this problem is on main too.
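The comment does not spell out the failure. One plausible reading, given that %4 consumes %3 as its RHS in the test case, is that all three matmuls share %arg0 as LHS but cannot all go into one horizontal group, because members of a group must not depend on each other. The following is a minimal sketch of that constraint, not the pass's actual code: `Graph`, `dependsOn`, and `buildGroup` are hypothetical names, and ops are encoded as plain integer indices.

```cpp
#include <cassert>
#include <vector>

// Hypothetical encoding: operands[i] lists the ops whose results op i consumes.
using Graph = std::vector<std::vector<int>>;

// Returns true if `op` transitively consumes the result of `producer`.
bool dependsOn(const Graph &operands, int op, int producer) {
  for (int input : operands[op]) {
    if (input == producer || dependsOn(operands, input, producer))
      return true;
  }
  return false;
}

// Greedily grows a group from `seed`. A candidate joins only if it is
// independent of every op already in the group, in both directions.
std::vector<int> buildGroup(const Graph &operands, int seed,
                            const std::vector<int> &candidates) {
  std::vector<int> group = {seed};
  for (int candidate : candidates) {
    bool independent = true;
    for (int member : group) {
      if (dependsOn(operands, candidate, member) ||
          dependsOn(operands, member, candidate)) {
        independent = false;
        break;
      }
    }
    if (independent)
      group.push_back(candidate);
  }
  return group;
}
```

With ops 0, 1, and 2 standing in for %2, %3, and %4, the graph is `{{}, {}, {1}}`, and `buildGroup(graph, 0, {1, 2})` keeps ops 0 and 1 while leaving op 2 out of the group.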
Good catch. Let me see if I can fix that.
We are (already) missing the documentation of the pass. The PR description looks good to me. Can you add such documentation to Passes.td? I.e., add a description section to the pass definition.
iree/compiler/src/iree/compiler/DispatchCreation/Passes.td
Lines 68 to 84 in 6ebfcaa
def FuseHorizontalContractionsPass:
    InterfacePass<"iree-dispatch-creation-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
  let summary = "Fuses horizontal contraction ops without fusions";
  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::tensor::TensorDialect",
  ];
  let options = [
    Option<"fusionLimit", "fusion-limit", "int",
           /*default=*/"3", "Maximum number of contractions fused into one">
  ];
  let statistics = [
    Statistic<"numFusionGroups", "num-fusion-groups", "Number of fusion groups found">,
    Statistic<"numSize2FusionGroups", "num-size-2-groups", "Number of fusion groups of size 2">,
    Statistic<"numSize3FusionGroups", "num-size-3-groups", "Number of fusion groups of size 3">
  ];
}
bfca1d5
to
dcfa537
Compare
@IanWood1 pushed a fix for this issue.
I merged this into #19847 (I think these changes enable more horizontal fusion in punet) and got a few failing dispatches: https://gist.github.com/IanWood1/2ddd601970b9d0197cf01aa91346e7e8. They are smaller sized, so I think they're going down a different pipeline.
Really? I have been trying this on punet locally. I didn't see any issue there.
968c63a
to
dd5341d
Compare
Ok, I understand what you are saying now. I think we will need the tile and fuse pipeline to handle this operation.
Looks good, just a few nits + a question about a check.
Cool, LGTM % nits!
dd5341d
to
2e943c9
Compare
Thanks for the reviews!
The change also allows doing horizontal fusion in cases where the LHS operand is the same, but the RHS/Outputs might be transposed. Signed-off-by: MaheshRavishankar <[email protected]>
… this yet. The previous implementation of horizontal fusion missed opportunities for horizontal fusion in SD3. They now get picked up, but the backend doesn't work on these. Dropping the flag is a no-op for the test since there was no horizontal fusion to start with. Signed-off-by: MaheshRavishankar <[email protected]>
2e943c9
to
ea20fec
Compare
LG, just some optional nits.
ea20fec
to
ac77783
Compare
This is an almost complete rewrite of the pass that fuses contractions horizontally. Instead of concatenating operands to map to a single GEMM and then slicing out the individual matmul results, the pass now creates a new operation whose operands are the common LHS, the RHS of each of the gemms, and the output of each of the gemms. The generated op yields the result of each constituent matmul.
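To make the shape of the fused operation concrete, here is a minimal sketch in plain C++ rather than in the IR the pass actually emits. It assumes two matmuls with identical shapes that share the LHS; `Matrix`, `matmul`, and `fusedMatmul` are hypothetical names chosen for illustration. The fused form is one loop nest that reads the LHS element once and updates one accumulator per constituent matmul, with no concatenation or slicing.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Reference: a single matmul, out[m][n] += lhs[m][k] * rhs[k][n].
Matrix matmul(const Matrix &lhs, const Matrix &rhs) {
  std::size_t M = lhs.size(), K = rhs.size(), N = rhs[0].size();
  Matrix out(M, std::vector<float>(N, 0.0f));
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k)
        out[m][n] += lhs[m][k] * rhs[k][n];
  return out;
}

// Fused form: operands are the common LHS plus each RHS; each output has
// its own accumulator, and both results are returned ("yielded") together.
// Assumes rhs0 and rhs1 have the same shape.
std::pair<Matrix, Matrix> fusedMatmul(const Matrix &lhs, const Matrix &rhs0,
                                      const Matrix &rhs1) {
  std::size_t M = lhs.size(), K = rhs0.size(), N = rhs0[0].size();
  Matrix out0(M, std::vector<float>(N, 0.0f));
  Matrix out1(M, std::vector<float>(N, 0.0f));
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k) {
        float a = lhs[m][k]; // shared LHS element, loaded once
        out0[m][n] += a * rhs0[k][n];
        out1[m][n] += a * rhs1[k][n];
      }
  return {out0, out1};
}
```

Calling `fusedMatmul(lhs, rhs0, rhs1)` gives the same pair of results as calling `matmul(lhs, rhs0)` and `matmul(lhs, rhs1)` separately.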
This also allows for the RHS/output indexing maps of the gemms to be mismatched, since only the LHS operand and indexing maps need to match. The change also permutes the iteration space of the gemms to ensure that the same indexing maps are used for the LHS across all the fused matmuls.
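The iteration-space permutation can be illustrated with a small sketch. This is not the pass's implementation: `Map`, `alignToReferenceLhs`, and `applyRenumbering` are hypothetical names, and indexing maps are assumed to be projected permutations stored as the list of iteration dimensions each operand dimension reads, so (d0, d1, d2) -> (d0, d2) becomes {0, 2}. Given a reference LHS map, the sketch renumbers a second gemm's iteration dimensions so that its LHS map becomes identical to the reference, then rewrites its RHS and output maps with the same renumbering.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical encoding: (d0, d1, d2) -> (d0, d2) is stored as {0, 2}.
using Map = std::vector<int>;

// Builds renumber[oldDim] = newDim such that `lhs`, rewritten with it,
// equals `refLhs`. Dimensions the LHS does not use fill the remaining
// slots in their original relative order.
std::vector<int> alignToReferenceLhs(const Map &refLhs, const Map &lhs,
                                     int numDims) {
  assert(refLhs.size() == lhs.size());
  std::vector<int> renumber(numDims, -1);
  std::vector<bool> newDimTaken(numDims, false);
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    renumber[lhs[i]] = refLhs[i];
    newDimTaken[refLhs[i]] = true;
  }
  int nextFree = 0;
  for (int oldDim = 0; oldDim < numDims; ++oldDim) {
    if (renumber[oldDim] != -1)
      continue;
    while (newDimTaken[nextFree])
      ++nextFree;
    renumber[oldDim] = nextFree;
    newDimTaken[nextFree] = true;
  }
  return renumber;
}

// Rewrites an indexing map under the dimension renumbering.
Map applyRenumbering(const Map &map, const std::vector<int> &renumber) {
  Map result;
  for (int dim : map)
    result.push_back(renumber[dim]);
  return result;
}
```

For example, if the reference gemm iterates as (m, n, k) with LHS map {0, 2}, and a second gemm iterates as (n, m, k) with LHS {1, 2}, RHS {2, 0}, and output {1, 0}, the renumbering swaps the first two dimensions. The second gemm's LHS then becomes {0, 2}, while its RHS and output become {2, 1} and {0, 1}; only the LHS maps are required to agree.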
The rest of the compiler stack has already been fixed up to handle such operations.