Incorrect broadcast elimination when both operands to elementwise are broadcasted #521

pranavm-nvidia opened this issue Feb 21, 2025 · 2 comments

ones in Tripy works by creating a scalar constant and broadcasting it up to the requested shape.

There seems to be a bug in how tensorrt.broadcast is optimized out when both operands of an elementwise operation are broadcast in this way:

x = tp.ones((2, 2))
y = tp.ones((2, 2))

print(x + y)

Input MLIR:

module @"outs_%t7_1" {
  func.func @main() -> tensor<?x?xf32> {
    %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
    %cst_i32 = tensorrt.constant dense<2> : tensor<2xi32>
    %0 = tensorrt.broadcast %cst_f32 broadcast_dims<> shape(%cst_i32 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
    %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
    %cst_i32_1 = tensorrt.constant dense<2> : tensor<2xi32>
    %1 = tensorrt.broadcast %cst_f32_0 broadcast_dims<> shape(%cst_i32_1 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
    %2 = tensorrt.element_wise <kSUM>(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    return %2 : tensor<?x?xf32>
  }
}

This turns into:

tensorrt.module @trt_engines {
  func.func @tensorrt_cluster() -> tensor<?x?xf32> {
    %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1xf32>
    %0 = tensorrt.element_wise <kSUM>(%cst_f32, %cst_f32 : tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
}

This incorrectly results in an output with a volume of 1:

tensor([[2.0000]], dtype=float32, loc=gpu:0, shape=(1, 1))

In this case, the broadcast instead needs to be rotated over the elementwise operation so that the output still has the correct shape.
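For illustration, the rotated form could look roughly like the following, where the sum is computed on the constants first and the result is then broadcast to the dynamic output shape. This is only a sketch based on the IR above; the exact broadcast_dims attribute and the shape-operand plumbing are assumptions:

%cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1xf32>
%cst_i32 = tensorrt.constant dense<2> : tensor<2xi32>
%0 = tensorrt.element_wise <kSUM>(%cst_f32, %cst_f32 : tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
%1 = tensorrt.broadcast %0 broadcast_dims<0, 1> shape(%cst_i32 : tensor<2xi32>) : tensor<1x1xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>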

pranavm-nvidia commented Feb 21, 2025

Seems like there is some logic that should address this case, but maybe it doesn't work for dynamic shapes?

    // You can't eliminate both broadcasts if the same unit-dim in both
    // operands is being broadcast to a larger value. We can do some further
    // simplification, but we leave that to other patterns.
    for (unsigned i = 0; i < input1.getType().getRank(); i++) {
      if (input1.getType().getDimSize(i) == 1 &&
          input2.getType().getDimSize(i) == 1 && op.getType().getDimSize(i) > 1)
        return failure();
    }

If dynamic dimensions are represented by negative integers, then maybe we could change the condition to: op.getType().getDimSize(i) != 1?
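A minimal sketch of that change, assuming getDimSize reports dynamic dimensions as a negative sentinel (e.g. ShapedType::kDynamic), so that any output dimension that is not statically 1 blocks the elimination:

    // Bail out if a unit dim shared by both operands could be broadcast to
    // anything other than 1, including a dynamic extent (negative sentinel).
    for (unsigned i = 0; i < input1.getType().getRank(); i++) {
      if (input1.getType().getDimSize(i) == 1 &&
          input2.getType().getDimSize(i) == 1 &&
          op.getType().getDimSize(i) != 1)
        return failure();
    }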

pranavm-nvidia commented:

Fixed by #523
