potentially incorrect transformation? #241

Closed
avik-pal opened this issue Jan 16, 2025 · 4 comments

@avik-pal
Collaborator

I will try to reduce this, but I'm opening an initial version for now.

module {
  func.func private @"*_broadcast_scalar"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"*_broadcast_scalar1"(%arg0: tensor<i64>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i64>, tensor<f32>) {
    %0 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %arg1 : tensor<f32>
    return %1, %arg0, %arg1 : tensor<f32>, tensor<i64>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func @main(%arg0: tensor<3x12x4xf32>) -> (tensor<3x12x12xf32>, tensor<3x12x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x12x4xf32>) -> tensor<4x12x3xf32>
    %1:2 = enzyme.batch @"*_broadcast_scalar"(%0) {batch_shape = array<i64: 4, 12, 3>} : (tensor<4x12x3xf32>) -> (tensor<4x12x3xf32>, tensor<4x12x3xf32>)
    %2 = stablehlo.convert %1#0 : tensor<4x12x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = enzyme.batch @identity_broadcast_scalar(%2) {batch_shape = array<i64: 4, 12, 3>} : (tensor<4x12x3xf32>) -> tensor<4x12x3xf32>
    %4 = stablehlo.convert %3 : tensor<4x12x3xf32>
    %5 = stablehlo.reduce(%4 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<4x12x3xf32>, tensor<f32>) -> tensor<12x3xf32>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<12x3xf32>) -> tensor<3x12xf32>
    %7 = stablehlo.reshape %6 : (tensor<3x12xf32>) -> tensor<3x12x1xf32>
    %8 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<3x12x1xf32>) -> tensor<1x12x3xf32>
    %9 = stablehlo.transpose %1#1, dims = [1, 0, 2] : (tensor<4x12x3xf32>) -> tensor<12x4x3xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x3xf32>
    %10 = stablehlo.transpose %9, dims = [2, 0, 1] : (tensor<12x4x3xf32>) -> tensor<3x12x4xf32>
    %11 = stablehlo.transpose %1#1, dims = [2, 0, 1] : (tensor<4x12x3xf32>) -> tensor<3x4x12xf32>
    %12 = stablehlo.convert %10 : tensor<3x12x4xf32>
    %13 = stablehlo.convert %11 : tensor<3x4x12xf32>
    %14 = stablehlo.dot_general %12, %13, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x12x4xf32>, tensor<3x4x12xf32>) -> tensor<3x12x12xf32>
    %15 = stablehlo.transpose %14, dims = [1, 2, 0] : (tensor<3x12x12xf32>) -> tensor<12x12x3xf32>
    %c = stablehlo.constant dense<2> : tensor<12x12x3xi64>
    %16:3 = enzyme.batch @"*_broadcast_scalar1"(%c, %15) {batch_shape = array<i64: 12, 12, 3>} : (tensor<12x12x3xi64>, tensor<12x12x3xf32>) -> (tensor<12x12x3xf32>, tensor<12x12x3xi64>, tensor<12x12x3xf32>)
    %17 = stablehlo.convert %16#0 : tensor<12x12x3xf32>
    %18 = stablehlo.transpose %8, dims = [1, 0, 2] : (tensor<1x12x3xf32>) -> tensor<12x1x3xf32>
    %19 = stablehlo.broadcast_in_dim %8, dims = [0, 1, 2] : (tensor<1x12x3xf32>) -> tensor<12x12x3xf32>
    %20:3 = enzyme.batch @"+_broadcast_scalar"(%19, %17) {batch_shape = array<i64: 12, 12, 3>} : (tensor<12x12x3xf32>, tensor<12x12x3xf32>) -> (tensor<12x12x3xf32>, tensor<12x12x3xf32>, tensor<12x12x3xf32>)
    %21 = stablehlo.convert %20#0 : tensor<12x12x3xf32>
    %22 = stablehlo.transpose %21, dims = [2, 1, 0] : (tensor<12x12x3xf32>) -> tensor<3x12x12xf32>
    %23 = stablehlo.transpose %1#1, dims = [2, 1, 0] : (tensor<4x12x3xf32>) -> tensor<3x12x4xf32>
    return %22, %23 : tensor<3x12x12xf32>, tensor<3x12x4xf32>
  }
}
module {
  func.func private @"*_broadcast_scalar"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"*_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func @main(%arg0: tensor<3x12x4xf32>) -> (tensor<3x12x12xf32>, tensor<3x12x4xf32>) {
    %cst = stablehlo.constant dense<2.000000e+00> : tensor<12x12x3xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x12x4xf32>) -> tensor<4x12x3xf32>
    %1:2 = enzyme.batch @"*_broadcast_scalar"(%0) {batch_shape = array<i64: 4, 12, 3>} : (tensor<4x12x3xf32>) -> (tensor<4x12x3xf32>, tensor<4x12x3xf32>)
    %2 = enzyme.batch @identity_broadcast_scalar(%1#0) {batch_shape = array<i64: 4, 12, 3>} : (tensor<4x12x3xf32>) -> tensor<4x12x3xf32>
    %3 = stablehlo.reduce(%2 init: %cst_0) applies stablehlo.add across dimensions = [0] : (tensor<4x12x3xf32>, tensor<f32>) -> tensor<12x3xf32>
    %4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<12x3xf32>) -> tensor<3x12xf32>
    %5 = stablehlo.reshape %4 : (tensor<3x12xf32>) -> tensor<3x12x1xf32>
    %6 = stablehlo.transpose %5, dims = [2, 1, 0] : (tensor<3x12x1xf32>) -> tensor<1x12x3xf32>
    %7 = stablehlo.dot_general %1#1, %1#1, batching_dims = [2] x [2], contracting_dims = [0] x [0] : (tensor<4x12x3xf32>, tensor<4x12x3xf32>) -> tensor<3x12x12xf32>
    %8 = stablehlo.transpose %7, dims = [1, 2, 0] : (tensor<3x12x12xf32>) -> tensor<12x12x3xf32>
    %9:3 = enzyme.batch @"*_broadcast_scalar1"(%cst, %8) {batch_shape = array<i64: 12, 12, 3>} : (tensor<12x12x3xf32>, tensor<12x12x3xf32>) -> (tensor<12x12x3xf32>, tensor<12x12x3xf32>, tensor<12x12x3xf32>)
    %10 = stablehlo.broadcast_in_dim %6, dims = [0, 1, 2] : (tensor<1x12x3xf32>) -> tensor<12x12x3xf32>
    %11:3 = enzyme.batch @"+_broadcast_scalar"(%10, %9#0) {batch_shape = array<i64: 12, 12, 3>} : (tensor<12x12x3xf32>, tensor<12x12x3xf32>) -> (tensor<12x12x3xf32>, tensor<12x12x3xf32>, tensor<12x12x3xf32>)
    %12 = stablehlo.transpose %11#0, dims = [2, 1, 0] : (tensor<12x12x3xf32>) -> tensor<3x12x12xf32>
    %13 = stablehlo.transpose %1#1, dims = [2, 1, 0] : (tensor<4x12x3xf32>) -> tensor<3x12x4xf32>
    return %12, %13 : tensor<3x12x12xf32>, tensor<3x12x4xf32>
  }
}

Running the batch pass produces:

envs/crash.mlir:10:10: error: Mismatched dimension sizes 12 and 3 in dimension 0
    %1 = stablehlo.multiply %arg0, %arg1 : tensor<f32>
         ^
envs/crash.mlir:29:12: note: called from
    %9:3 = enzyme.batch @"*_broadcast_scalar1"(%c, %8) {batch_shape = array<i64: 12, 12, 3>} : (tensor<12x12x3xf32>, tensor<12x12x3xf32>) -> (tensor<12x12x3xf32>, tensor<12x12x3xf32>, tensor<12x12x3xf32>)
           ^
envs/crash.mlir:10:10: remark: location of op
    %1 = stablehlo.multiply %arg0, %arg1 : tensor<f32>
         ^
envs/crash.mlir:29:12: note: called from
    %9:3 = enzyme.batch @"*_broadcast_scalar1"(%c, %8) {batch_shape = array<i64: 12, 12, 3>} : (tensor<12x12x3xf32>, tensor<12x12x3xf32>) -> (tensor<12x12x3xf32>, tensor<12x12x3xf32>, tensor<12x12x3xf32>)
           ^
LLVM ERROR: Failed to infer result type(s):
"stablehlo.multiply"(...) {} : (tensor<12x3x12xf32>, tensor<3x12x12xf32>) -> ( ??? )
@avik-pal
Collaborator Author

module {
  func.func @main(%arg0: tensor<2x12x4xf32>) -> (tensor<2x12x12xf32>, tensor<2x12x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x12x4xf32>) -> tensor<4x12x2xf32>
    %cst = stablehlo.constant dense<2.000000e+00> : tensor<12x12x2xf32>
    %1 = stablehlo.transpose %0, dims = [1, 0, 2] : (tensor<4x12x2xf32>) -> tensor<12x4x2xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x2xf32>
    %2 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<12x4x2xf32>) -> tensor<2x12x4xf32>
    %3 = stablehlo.transpose %0, dims = [2, 0, 1] : (tensor<4x12x2xf32>) -> tensor<2x4x12xf32>
    %4 = stablehlo.convert %2 : tensor<2x12x4xf32>
    %5 = stablehlo.convert %3 : tensor<2x4x12xf32>
    %6 = stablehlo.dot_general %4, %5, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x12x4xf32>, tensor<2x4x12xf32>) -> tensor<2x12x12xf32>
    %7 = stablehlo.transpose %6, dims = [1, 2, 0] : (tensor<2x12x12xf32>) -> tensor<12x12x2xf32>
    %8 = stablehlo.multiply %cst, %7 : tensor<12x12x2xf32>
    %9 = stablehlo.transpose %8, dims = [2, 1, 0] : (tensor<12x12x2xf32>) -> tensor<2x12x12xf32>
    %10 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<4x12x2xf32>) -> tensor<2x12x4xf32>
    return %9, %10 : tensor<2x12x12xf32>, tensor<2x12x4xf32>
  }
}

This is enough to cause the crash.
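Distilling further, the multiply of a splat constant with a transposed value seems to be the pair of ops that matters (copied verbatim from the module above):

%cst = stablehlo.constant dense<2.000000e+00> : tensor<12x12x2xf32>
%7 = stablehlo.transpose %6, dims = [1, 2, 0] : (tensor<2x12x12xf32>) -> tensor<12x12x2xf32>
%8 = stablehlo.multiply %cst, %7 : tensor<12x12x2xf32>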

@wsmoses
Member

wsmoses commented Jan 17, 2025

I just updated per @jumerckx's PR for batching. Does that resolve it? ... Actually no, I suppose not, since the latter case doesn't have a call.

I think the issue here is that dot_general needs a custom batch interface implementation like transpose has.
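For context, a batch implementation for dot_general would have to prepend the batch dimensions and renumber the existing ones, roughly as below (an illustrative sketch with a single batch dimension of size 5, not actual interface code):

// Unbatched op, as in @main above:
//   dot_general batching_dims = [0] x [0], contracting_dims = [2] x [1]
//     : (tensor<3x12x4xf32>, tensor<3x4x12xf32>) -> tensor<3x12x12xf32>
// Batched form: the new leading dimension joins the batching dims and every
// other dimension index shifts up by one.
%r = stablehlo.dot_general %lhs, %rhs, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<5x3x12x4xf32>, tensor<5x3x4x12xf32>) -> tensor<5x3x12x12xf32>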

@avik-pal
Collaborator Author

This failure seems to be very specific to a dot_general followed by a multiply. If I replace the 2 * (x * x) with y = x * x; y + y (i.e. replace the multiply by 2 with an add), it no longer crashes.
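In terms of the reduced module above, that corresponds to replacing the final multiply with a self-add (a sketch; the surrounding ops are unchanged):

%7 = stablehlo.transpose %6, dims = [1, 2, 0] : (tensor<2x12x12xf32>) -> tensor<12x12x2xf32>
// crashes:        %8 = stablehlo.multiply %cst, %7 : tensor<12x12x2xf32>
// does not crash:
%8 = stablehlo.add %7, %7 : tensor<12x12x2xf32>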

@avik-pal
Collaborator Author

avik-pal commented Feb 8, 2025

(gdb) bt
#0 __pthread_kill_implementation (no_tid=0, signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:44
#1 __pthread_kill_internal (signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:78
#2 __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3 0x00007ffff784527e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4 0x00007ffff78288ff in __GI_abort () at ./stdlib/abort.c:79
#5 0x0000555555e53775 in llvm::report_fatal_error(llvm::Twine const&, bool) [clone .cold] ()
#6 0x0000555559a58329 in llvm::report_fatal_error(llvm::StringRef, bool) ()
#7 0x0000555557b2d0b2 in mlir::detail::reportFatalInferReturnTypesError(mlir::OperationState&) ()
#8 0x00005555570ae5be in mlir::stablehlo::AndOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Value, mlir::Value) ()
#9 0x000055555617031d in llvm::LogicalResult (anonymous namespace)::simplifyBinaryOpWithTranspose<mlir::stablehlo::MulOp>(mlir::stablehlo::MulOp, mlir::PatternRewriter&) ()
#10 0x00005555579590ec in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::{lambda()#1}::operator()() const ()
#11 0x0000555557959be0 in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) ()
#12 0x0000555557951958 in (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist() ()
#13 0x00005555579527b9 in mlir::applyPatternsGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) ()
#14 0x0000555556151460 in (anonymous namespace)::EnzymeHLOOptPass::runOnOperation() ()
#15 0x0000555557b569df in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) ()
#16 0x0000555557b56eac in mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) ()
#17 0x0000555557b57e5f in mlir::PassManager::run(mlir::Operation*) ()
#18 0x0000555555f40b1f in performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) ()
#19 0x0000555555f413e6 in processBuffer(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPoolInterface*) ()
#20 0x0000555555f41575 in llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::{lambda(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)#1}>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) ()
#21 0x0000555557ca198c in mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) ()
#22 0x0000555555f38c5a in mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) ()
#23 0x0000555555f41734 in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) ()
#24 0x0000555555f41c8e in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) ()
#25 0x0000555555e53f9a in main ()
