-
Notifications
You must be signed in to change notification settings - Fork 9
[DRAFT][s4xbf16] JoinOp vectorization patch #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
For small-M shapes, for best perf, we'll additionally want XLA to swap A/B so that the LHS of dot is quantized, and then set envvar DISABLE_MMA_V3 to force Ampere-MMA. |
loislo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this cl needs a rebase
|
@loislo, by rebase, do you mean to rebase on openxla/triton's llvm-head, or OAI's triton's main? |
…leaveTMem.cpp (triton-lang#7924) `TritonNvidiaGPU/interleave_tmem.mlir` fails under address sanitizer. The `ConstantIntOp` operations were created without attachment to any block in http://github.com/triton-lang/triton/pull/7622, which caused a memory leak. This change addresses the problem by adding an insertion point. <details open> <summary>Full log</summary> ================================================================= ==3831==ERROR: LeakSanitizer: detected memory leaks Direct leak of 576 byte(s) in 6 object(s) allocated from: #0 0x55c3eca39164 in malloc [third_party/llvm/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp:67](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp?l=67&ws=tap-presubmit-server/421956858&snapshot=2):3 #1 0x55c3f176afb3 in mlir::Operation::create(mlir::Location, mlir::OperationName, mlir::TypeRange, mlir::ValueRange, mlir::DictionaryAttr, mlir::OpaqueProperties, mlir::BlockRange, unsigned int) [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:113](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=113&ws=tap-presubmit-server/421956858&snapshot=2):46 #2 0x55c3f176a90c in create [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:74](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=74&ws=tap-presubmit-server/421956858&snapshot=2):10 #3 0x55c3f176a90c in mlir::Operation::create(mlir::Location, mlir::OperationName, mlir::TypeRange, mlir::ValueRange, mlir::NamedAttrList&&, mlir::OpaqueProperties, mlir::BlockRange, mlir::RegionRange) [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:57](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=57&ws=tap-presubmit-server/421956858&snapshot=2):7 #4 0x55c3f176a61b in mlir::Operation::create(mlir::OperationState const&) [third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp:35](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Operation.cpp?l=35&ws=tap-presubmit-server/421956858&snapshot=2):7 #5 0x55c3f1678a78 in mlir::OpBuilder::create(mlir::OperationState const&) [third_party/llvm/llvm-project/mlir/lib/IR/Builders.cpp:453](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/IR/Builders.cpp?l=453&ws=tap-presubmit-server/421956858&snapshot=2):17 #6 0x55c3ecf3668f in mlir::arith::ConstantIntOp mlir::OpBuilder::create<mlir::arith::ConstantIntOp, int, int>(mlir::Location, int&&, int&&) [third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h:507](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h?l=507&ws=tap-presubmit-server/421956858&snapshot=2):16 #7 0x55c3eefa690a in findBufferAccessMemdescSubview [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:75](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=75&ws=tap-presubmit-server/421956858&snapshot=2):33 #8 0x55c3eefa690a in mlir::triton::nvidia_gpu::(anonymous namespace)::findBufferAccess(mlir::Value) [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:151](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=151&ws=tap-presubmit-server/421956858&snapshot=2):12 #9 0x55c3eefa70e7 in mlir::triton::nvidia_gpu::(anonymous namespace)::findBufferAccess(mlir::Value) [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:156](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=156&ws=tap-presubmit-server/421956858&snapshot=2):34 #10 0x55c3eefa4c0c in tmemMayAlias [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:173](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=173&ws=tap-presubmit-server/421956858&snapshot=2):28 #11 0x55c3eefa4c0c in sinkOps [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:227](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=227&ws=tap-presubmit-server/421956858&snapshot=2):36 #12 0x55c3eefa4c0c in trySinkOp [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:253](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=253&ws=tap-presubmit-server/421956858&snapshot=2):10 #13 0x55c3eefa4c0c in mlir::triton::nvidia_gpu::TritonNvidiaGPUInterleaveTMemPass::runOnOperation() [third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp:275](https://cs.corp.google.com/piper///depot/google3/third_party/triton/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp?l=275&ws=tap-presubmit-server/421956858&snapshot=2):14 #14 0x55c3f1560ad1 in operator() [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:553](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=553&ws=tap-presubmit-server/421956858&snapshot=2):17 #15 0x55c3f1560ad1 in void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=46&ws=tap-presubmit-server/421956858&snapshot=2):12 #16 0x55c3f1559920 in operator() [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=69&ws=tap-presubmit-server/421956858&snapshot=2):12 #17 0x55c3f1559920 in executeAction<mlir::PassExecutionAction, mlir::Pass &> [third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h?l=280&ws=tap-presubmit-server/421956858&snapshot=2):7 #18 0x55c3f1559920 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:547](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=547&ws=tap-presubmit-server/421956858&snapshot=2):21 #19 0x55c3f155d46f in runPipeline [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:619](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=619&ws=tap-presubmit-server/421956858&snapshot=2):16 #20 0x55c3f155d46f in mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:933](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=933&ws=tap-presubmit-server/421956858&snapshot=2):10 #21 0x55c3f155d15b in mlir::PassManager::run(mlir::Operation*) [third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp:913](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Pass/Pass.cpp?l=913&ws=tap-presubmit-server/421956858&snapshot=2):60 #22 0x55c3ed0a8b20 in performActions(llvm::raw_ostream&, std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:477](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=477&ws=tap-presubmit-server/421956858&snapshot=2):17 #23 0x55c3ed0a8363 in processBuffer [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:553](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=553&ws=tap-presubmit-server/421956858&snapshot=2):12 #24 0x55c3ed0a8363 in operator() [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:642](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=642&ws=tap-presubmit-server/421956858&snapshot=2):12 #25 0x55c3ed0a8363 in llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef const&, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef const&, llvm::raw_ostream&) [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=46&ws=tap-presubmit-server/421956858&snapshot=2):12 triton-lang#26 0x55c3f17bd34f in operator() [third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h?l=69&ws=tap-presubmit-server/421956858&snapshot=2):12 triton-lang#27 0x55c3f17bd34f in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::MemoryBufferRef const&, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) [third_party/llvm/llvm-project/mlir/lib/Support/ToolUtilities.cpp:30](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Support/ToolUtilities.cpp?l=30&ws=tap-presubmit-server/421956858&snapshot=2):12 triton-lang#28 0x55c3ed09d0c6 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:647](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=647&ws=tap-presubmit-server/421956858&snapshot=2):26 triton-lang#29 0x55c3ed09d67f in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:693](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=693&ws=tap-presubmit-server/421956858&snapshot=2):14 triton-lang#30 0x55c3ed09dc59 in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) [third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:709](https://cs.corp.google.com/piper///depot/google3/third_party/llvm/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp?l=709&ws=tap-presubmit-server/421956858&snapshot=2):10 triton-lang#31 0x55c3eca74a70 in main [third_party/triton/bin/triton-opt.cpp:14](https://cs.corp.google.com/piper///depot/google3/third_party/triton/bin/triton-opt.cpp?l=14&ws=tap-presubmit-server/421956858&snapshot=2):33 triton-lang#32 0x7f1fd58613d3 in __libc_start_main (/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId: 9a996398ce14a94560b0c642eb4f6e94) triton-lang#33 0x55c3ec995aa9 in _start /usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120 </details> --------- Co-authored-by: Thomas Raoux <[email protected]>
Do not merge until rebased on upstream Triton up to triton-lang@c1ed673
This is 1 of the 2 patches needed to improve int4xbf16 GEMM perf.
This is needed because joinOp by default interleaves every element of the two input matrices. In the case of bf16, this means Triton will extract the 2x bf16 values out of the 32-bit register and re-insert them into a new register. This results in many
movinstructions before MMA. On certain shapes, this could mean a ~10% perf penalty.This PR addresses the above by situationally "vectorizing" the interleaving; namely, join every two elements instead of one. This avoids the need to extract values out of registers. Of course, this would also require one to modify the inline_asm logic before the join to produce the correct layout.
cc @gflegar