Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DispatchCreation] Modify the generated fused op to not use concats. #19980

Conversation

MaheshRavishankar
Copy link
Contributor

This is an almost complete rewrite of the pass to fuse contractions horizontally which instead of concatenating operands to map to a GEMM, followed by slices to extract the individual matmul results; the pass now just creates a new operation with the operands being the common LHS, the RHS of each of the gemms, and the output of each of the gemms. The generated op yields the result of each constituent matmul.
This also allows for the RHS/output indexing maps of the gemms to be mismatched, since only the LHS operand and indexing maps need to match. The change also permutes the iteration space of the gemms to ensure that the same indexing maps are used for the LHS across all the fused matmuls.

The rest of the compiler stack has already been fixed up to handle such operations.

@MaheshRavishankar
Copy link
Contributor Author

It might be better to just review the new changes by themselves and ignore the diff. The pass is essentially rewritten.

@IanWood1
Copy link
Contributor

There's a problem with how ops are grouped. allOps never gets updated with the other ops determined to be fusible with the root op. Also, the candidates to fuse need to be iterated over in dominance order to ensure that, using the example below, %3 gets grouped before %4

util.func public @test_partial_horizontal_fuse(%arg0: tensor<640x640xf32>, %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, %arg3: tensor<640x640xf32>) -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<640x640xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
}

This isn't directly related to the changes you made, I think this problem is on main too.

@MaheshRavishankar
Copy link
Contributor Author

There's a problem with how ops are grouped. allOps never gets updated with the other ops determined to be fusible with the root op. Also, the candidates to fuse need to be iterated over in dominance order to ensure that, using the example below, %3 gets grouped before %4

util.func public @test_partial_horizontal_fuse(%arg0: tensor<640x640xf32>, %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, %arg3: tensor<640x640xf32>) -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<640x640xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
}

This isn't directly related to the changes you made, I think this problem is on main too.

Good catch. Let me see if I can fix that.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are (already) missing the documentation of the pass. The PR description looks good to me. Can you add such documentation to Passes.td? I.e., add a description section to the pass definition.

def FuseHorizontalContractionsPass:
InterfacePass<"iree-dispatch-creation-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
let summary = "Fuses horizontal contraction ops without fusions";
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::tensor::TensorDialect",
];
let options = [
Option<"fusionLimit", "fusion-limit", "int",
/*default=*/"3", "Maximum number of contractions fused into one">
];
let statistics = [
Statistic<"numFusionGroups", "num-fusion-groups", "Number of fusion groups found">,
Statistic<"numSize2FusionGroups", "num-size-2-groups", "Number of fusion groups of size 2">,
Statistic<"numSize3FusionGroups", "num-size-3-groups", "Number of fusion groups of size 3">
];
}

@MaheshRavishankar MaheshRavishankar force-pushed the shared/noconcatHorizontalFusionChanges branch from bfca1d5 to dcfa537 Compare February 14, 2025 19:56
@MaheshRavishankar
Copy link
Contributor Author

There's a problem with how ops are grouped. allOps never gets updated with the other ops determined to be fusible with the root op. Also, the candidates to fuse need to be iterated over in dominance order to ensure that, using the example below, %3 gets grouped before %4

util.func public @test_partial_horizontal_fuse(%arg0: tensor<640x640xf32>, %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, %arg3: tensor<640x640xf32>) -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<640x640xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
}

This isn't directly related to the changes you made, I think this problem is on main too.

@IanWood1 pushed a fix for this issue.

@IanWood1
Copy link
Contributor

IanWood1 commented Feb 14, 2025

I merged this into #19847 (I think these changes enable more horizontal fusion in punet) and got a few failing dispatches https://gist.github.com/IanWood1/2ddd601970b9d0197cf01aa91346e7e8.

They are smaller sized so I think there going down a different pipeline

@MaheshRavishankar
Copy link
Contributor Author

I merged this into #19847 (I think these changes enable more horizontal fusion in punet) and got a few failing dispatches https://gist.github.com/IanWood1/2ddd601970b9d0197cf01aa91346e7e8.

They are smaller sized so I think there going down a different pipeline

Really. I have been trying this on punet locally. I didnt see any issue there.

@MaheshRavishankar MaheshRavishankar force-pushed the shared/noconcatHorizontalFusionChanges branch from 968c63a to dd5341d Compare February 15, 2025 22:02
@MaheshRavishankar
Copy link
Contributor Author

I merged this into #19847 (I think these changes enable more horizontal fusion in punet) and got a few failing dispatches https://gist.github.com/IanWood1/2ddd601970b9d0197cf01aa91346e7e8.

They are smaller sized so I think there going down a different pipeline

Ok, I understand what you are saying now. I think we will need the tile and fuse pipeline to handle this operation.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just few nits + a question about a check.

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, LGTM % nits!

@MaheshRavishankar MaheshRavishankar force-pushed the shared/noconcatHorizontalFusionChanges branch from dd5341d to 2e943c9 Compare February 18, 2025 00:31
Copy link
Contributor Author

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reviews!

MaheshRavishankar and others added 4 commits February 17, 2025 18:37
The change also allows doing horizontal fusion in cases where the LHS
operand is the same, but the RHS/Outputs might be transposed.

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
… this yet.

Previous implementation of horizontal fusion missed opportunities for
horizontal fusion in SD3, but now they do get picked up, but the
backend doesnt work on these. Dropping the flag is a no-op for the
test since there was no horizontal fusion to start with.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the shared/noconcatHorizontalFusionChanges branch from 2e943c9 to ea20fec Compare February 18, 2025 00:38
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, just some optional nits.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the shared/noconcatHorizontalFusionChanges branch from ea20fec to ac77783 Compare February 18, 2025 21:22
@MaheshRavishankar MaheshRavishankar merged commit b85c180 into iree-org:main Feb 18, 2025
40 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants