diff --git a/.bazelrc b/.bazelrc index 0074cf53fe..fbc2f1e8e3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -24,6 +24,7 @@ common common --define framework_shared_object=true common --define tsl_protobuf_header_only=true common --define=allow_oversize_protos=true +common --check_visibility=false # Some targets have the same py source file, but use different diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 975f166f9a..5f175c0c62 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -892,17 +892,23 @@ cc_library( "//src/external/isl:Isl", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@enzyme//:EnzymeMLIR", "@jax//jaxlib/gpu:triton_cc_proto", "@llvm-project//llvm:Core", "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:OrcTargetProcess", "@llvm-project//llvm:Passes", "@llvm-project//llvm:Scalar", "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -913,6 +919,7 @@ cc_library( "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:CommonFolders", @@ -992,10 +999,36 @@ cc_library( "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", "@stablehlo//:stablehlo_type_inference", + "@triton//:GluonDialect", "@triton//:TritonDialects", + "@triton//:TritonGPUToLLVM", + "@triton//:TritonGPUTransforms", + "@triton//:TritonLLVMIR", + "@triton//:TritonNvidiaGPUTransforms", + "@triton//:TritonToTritonGPU", 
"@triton//:TritonToTritonGPUPasses", + "@triton//:TritonTransforms", + "@triton//:WarpSpecialization", + "@triton//:triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@triton//:triton_nvidia_gpu_transforms_inc_gen", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:NVHopperTransforms", + "@triton//third_party/nvidia:NVWSDialect", + "@triton//third_party/nvidia:NVWSTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonGPUIR", + "@triton//third_party/proton:ProtonIR", + "@xla//xla/backends/gpu/codegen/triton:compilation_pipeline", "@xla//xla/mlir/utils:type_util", "@xla//xla/mlir_hlo", + "@xla//xla/pjrt:triton", + "@xla//xla/stream_executor:device_description", + "@xla//xla/stream_executor/cuda:cuda_compute_capability", + "@xla//xla/stream_executor/rocm:rocm_compute_capability", "@zlib", ], ) @@ -1095,6 +1128,7 @@ cc_library( # Triton "@triton//:TritonDialects", + "@triton//:GluonDialect", "@triton//:TritonGPUToLLVM", "@triton//:TritonGPUTransforms", "@triton//:TritonLLVMIR", @@ -1114,6 +1148,8 @@ cc_library( "@triton//third_party/nvidia:NVWSDialect", "@triton//third_party/nvidia:NVWSTransforms", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonIR", + "@triton//third_party/proton:ProtonGPUIR", # Shardy stuff "@shardy//shardy/dialect/sdy/ir:dialect", diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp new file mode 100644 index 0000000000..8a7664a685 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -0,0 +1,202 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + 
+#include "xla/backends/gpu/codegen/triton/compilation_pipeline.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/rocm/rocm_compute_capability.h" + +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "nvidia/include/Dialect/NVWS/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "xla/pjrt/triton.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; +/* Records the module that owns the function called by `op`. Resolves op.getFnAttr() through `symbolTable`, takes the enclosing parent of that function (a built-in module, per the error text) and requires that parent to have a further enclosing parent (named `triton_ext.module` in the error text). On success stores tritonKernels[op] = the first parent; on any failed lookup it emits an error on `op` and stores nothing. NOTE(review): the template arguments of DenseMap and getParentOfType are not visible in this patch view - confirm the exact op types. */ +void collectTritonKernels( + DenseMap &tritonKernels, + SymbolTableCollection &symbolTable, triton_ext::TritonCallOp op) { + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError() << "Failed to find function '" << op.getFn() << "' in " + << "module"; + return; + } +/* Parent of the resolved function; this is the value recorded in the map and later handed to runPipeline by the pass. */ + auto wrappedMod = funcOp->getParentOfType(); + if 
(!wrappedMod) { + op->emitError() << "Failed to find parent built-in module."; + return; + } +/* Presence check only: ttModOP is not used again after the null test. */ + auto ttModOP = wrappedMod->getParentOfType(); + if (!ttModOP) { + op->emitError() << "No `triton_ext.module` found!"; + return; + } + + tritonKernels[op] = wrappedMod; + return; +} +/* Pass that rewrites every triton_ext::TritonCallOp found in the module into an enzymexla::KernelCallOp, after running the XLA Triton pipelines on the module that owns the called function. */ +struct LowerTritonPass + : public mlir::enzyme::impl::LowerTritonPassBase { + using Base::Base; +/* Steps: (1) build the GPU compute capability from the `backend` and `computeCapability` options; (2) collect all TritonCallOps together with their kernel modules; (3) lower each kernel module and replace its call op; (4) signal pass failure if any lowering failed. */ + void runOnOperation() override { + auto modOp = getOperation(); +/* Only the cuda and rocm backends are accepted. NOTE(review): both early returns below emit an error but do not call signalPassFailure() - confirm that is intended. */ + stream_executor::GpuComputeCapability gpuCC; + if (backend == "cuda") { + auto cudaCC = + stream_executor::CudaComputeCapability::FromString(computeCapability); + if (!cudaCC.ok()) { + modOp->emitError("Unsupported cuda compute capability: ") + << cudaCC.status().ToString(); + return; + } + gpuCC = stream_executor::GpuComputeCapability(cudaCC.value()); + } else if (backend == "rocm") { + auto rocmCC = stream_executor::RocmComputeCapability(computeCapability); + gpuCC = stream_executor::GpuComputeCapability(rocmCC); + } else { + modOp->emitError("Unsupported backend: ") << backend; + return; + } +/* NOTE(review): errors emitted inside collectTritonKernels are not propagated here; calls that fail the lookup are simply absent from the map and stay unlowered. */ + DenseMap tritonKernels; + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(modOp); + modOp->walk([&](triton_ext::TritonCallOp op) { + collectTritonKernels(tritonKernels, symbolTable, op); + }); +/* Appended to once per kernel below, but never read in this function. */ + SmallVector clusterInfos; + + OpBuilder builder(modOp); +/* Launch parameters come from enzymexla.* attributes on the kernel module, defaulting to numWarps = 4, numCtas = 1 and numStages = 3. NOTE(review): the pipeline runs once per map entry, so two calls resolving to the same module would lower that module twice - confirm this cannot happen. */ + bool anyFailed = false; + for (auto [ttCallOp, innerMod] : tritonKernels) { + int32_t numWarps = 4; + if (innerMod->hasAttrOfType("enzymexla.num_warps")) { + numWarps = innerMod->getAttrOfType("enzymexla.num_warps") + .getInt(); + } + int32_t numCtas = 1; + if (innerMod->hasAttrOfType("enzymexla.num_ctas")) { + numCtas = + innerMod->getAttrOfType("enzymexla.num_ctas").getInt(); + } + int32_t numStages = 3; + if (innerMod->hasAttrOfType("enzymexla.num_stages")) { + numStages = innerMod->getAttrOfType("enzymexla.num_stages") + .getInt(); + } +/* Both XLA pipelines are appended to one pass manager, which is then run on the kernel module. */ + OpPassManager pm; + + xla::gpu::CreateTritonXlaPipeline(&pm, gpuCC, true, true, numStages); + 
mlir::triton::nvidia_gpu::ClusterInfo clusterInfo; + xla::gpu::CreateTritonPipeline(&pm, gpuCC, numWarps, numCtas, numStages, + clusterInfo); + clusterInfos.push_back(clusterInfo); +/* A failed kernel is reported and skipped: its call op is left in place and the remaining kernels are still processed. */ + if (failed(runPipeline(pm, innerMod))) { + innerMod->emitError( + "Failed to lower Triton kernel to TritonGPU kernel"); + anyFailed = true; + continue; + } +/* Read after the pipeline has run; falls back to 32 when the attribute is absent. */ + int32_t threadsPerWarp = 32; + if (innerMod->hasAttrOfType("ttg.threads_per_warp")) { + threadsPerWarp = + innerMod->getAttrOfType("ttg.threads_per_warp") + .getInt(); + } + + builder.setInsertionPoint(ttCallOp); +/* NOTE(review): the ttg.shared attribute is used without a null check, unlike the hasAttrOfType guards on the other attributes - confirm the pipeline always sets it. The constants built below reuse the type of the gridx operand. */ + auto sharedMemSizeAttr = + innerMod->getAttrOfType("ttg.shared"); + auto sharedMemSize = sharedMemSizeAttr.getInt(); + auto shmemOpType = ttCallOp.getGridx().getType(); + auto shmemOp = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, sharedMemSize))); +/* blockX holds threadsPerWarp * numWarps and blockYZ holds 1; they are passed to the kernel call right after the grid operands, with blockYZ used twice. */ + auto blockX = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, threadsPerWarp * numWarps))); + auto blockYZ = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, 1))); + + SmallVector newInputs(ttCallOp.getInputs().begin(), + ttCallOp.getInputs().end()); + // we don't use the next 2 inputs + auto scratchSpace = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), + RankedTensorType::get({}, builder.getI8Type()), + cast( + makeAttr(RankedTensorType::get({}, builder.getI8Type()), 0))); + newInputs.push_back(scratchSpace); + newInputs.push_back(scratchSpace); +/* The new call keeps the callee, grid and cluster operands, result types and attributes of the original call; only the block sizes, the shared-memory size and the two appended inputs are new. */ + auto kernelCallOp = enzymexla::KernelCallOp::create( + builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(), + ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(), + ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp, + ttCallOp.getClusterx(), ttCallOp.getClustery(), + ttCallOp.getClusterz(), newInputs, ttCallOp.getBackendConfigAttr(), + ttCallOp.getOperandLayoutsAttr(), ttCallOp.getResultLayoutsAttr(), + 
ttCallOp.getArgAttrsAttr(), ttCallOp.getResAttrsAttr(), + ttCallOp.getOutputOperandAliasesAttr(), + ttCallOp.getXlaSideEffectFreeAttr()); + ttCallOp.replaceAllUsesWith(kernelCallOp); + ttCallOp.erase(); + } +/* Failure is signalled only after every kernel has been attempted. */ + if (anyFailed) { + signalPassFailure(); + return; + } + } +}; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 5bf0092012..52f3fec327 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1069,6 +1069,47 @@ def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass< >]; } +def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> { + let summary = "Lower Triton to kernel call"; + let dependentDialects = [ + "triton::TritonDialect", + "enzymexla::EnzymeXLADialect", + "func::FuncDialect", + "enzymexla::triton_ext::TritonExtDialect", + "stablehlo::StablehloDialect", + "triton::nvidia_gpu::TritonNvidiaGPUDialect", + "triton::instrument::TritonInstrumentDialect", + "triton::nvgpu::NVGPUDialect", + "triton::nvws::NVWSDialect", + "triton::amdgpu::TritonAMDGPUDialect", + "triton::proton::ProtonDialect", + "triton::proton::gpu::ProtonGPUDialect", + "triton::gluon::GluonDialect", + "triton::gpu::TritonGPUDialect", + "cf::ControlFlowDialect", + "math::MathDialect", + "arith::ArithDialect", + "scf::SCFDialect", + "gpu::GPUDialect", + "LLVM::LLVMDialect", + "NVVM::NVVMDialect", + "ROCDL::ROCDLDialect", + ]; + let options = [ + Option< + /*C++ variable name=*/"backend", + /*CLI argument=*/"backend", + /*type=*/"std::string", + /*default=*/"\"cuda\"", + /*description=*/"HW backend">, + Option< + /*C++ variable name=*/"computeCapability", + /*CLI argument=*/"compute_capability", + /*type=*/"std::string", + /*default=*/"\"8.0\"", + /*description=*/"Compute capability">]; +} + def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { let summary = "Legalize batching specific enzyme ops to stablehlo dialect"; let dependentDialects = [ diff --git 
a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 46ce854a3b..81c8f25363 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -105,14 +105,21 @@ #include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_export.h" #include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "nvidia/include/Dialect/NVWS/IR/Dialect.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "triton/Target/LLVMIR/Passes.h" @@ -218,6 +225,13 @@ void registerDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); } @@ -256,6 +270,13 @@ void loadAllRegisteredDialects(mlir::MLIRContext &context) { context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); context.loadDialect(); } diff 
--git a/test/lit_tests/triton/add_kernel.mlir b/test/lit_tests/triton/add_kernel.mlir index 8472a09ce2..b8c9943931 100644 --- a/test/lit_tests/triton/add_kernel.mlir +++ b/test/lit_tests/triton/add_kernel.mlir @@ -1,4 +1,5 @@ // RUN: enzymexlamlir-opt %s -canonicalize | FileCheck %s +// RUN: enzymexlamlir-opt %s -lower-triton | FileCheck %s --check-prefix=LOWER module { // CHECK: enzymexla_tt_ext.module