From 40c9af7f6b4846407e4e19304f82f924a076fa8e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Nov 2025 10:31:10 -0600 Subject: [PATCH 1/4] feat: lower triton via xla --- .bazelrc | 1 + src/enzyme_ad/jax/BUILD | 28 +++++ src/enzyme_ad/jax/Passes/LowerTriton.cpp | 124 +++++++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 41 ++++++++ src/enzyme_ad/jax/RegistryUtils.cpp | 21 ++++ test/lit_tests/triton/add_kernel.mlir | 1 + 6 files changed, 216 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/LowerTriton.cpp diff --git a/.bazelrc b/.bazelrc index 0074cf53fe..fbc2f1e8e3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -24,6 +24,7 @@ common common --define framework_shared_object=true common --define tsl_protobuf_header_only=true common --define=allow_oversize_protos=true +common --check_visibility=false # Some targets have the same py source file, but use different diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 975f166f9a..e20878f844 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -992,10 +992,35 @@ cc_library( "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", "@stablehlo//:stablehlo_type_inference", + "@triton//:GluonDialect", "@triton//:TritonDialects", + "@triton//:TritonGPUToLLVM", + "@triton//:TritonGPUTransforms", + "@triton//:TritonLLVMIR", + "@triton//:TritonNvidiaGPUTransforms", + "@triton//:TritonToTritonGPU", "@triton//:TritonToTritonGPUPasses", + "@triton//:TritonTransforms", + "@triton//:WarpSpecialization", + "@triton//:triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@triton//:triton_nvidia_gpu_transforms_inc_gen", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:NVHopperTransforms", + "@triton//third_party/nvidia:NVWSDialect", + "@triton//third_party/nvidia:NVWSTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonGPUIR", + "@triton//third_party/proton:ProtonIR", + "@xla//xla/backends/gpu/codegen/triton:compilation_pipeline", "@xla//xla/mlir/utils:type_util", "@xla//xla/mlir_hlo", + "@xla//xla/stream_executor:device_description", + "@xla//xla/stream_executor/cuda:cuda_compute_capability", + "@xla//xla/stream_executor/rocm:rocm_compute_capability", "@zlib", ], ) @@ -1095,6 +1120,7 @@ cc_library( # Triton "@triton//:TritonDialects", + "@triton//:GluonDialect", "@triton//:TritonGPUToLLVM", "@triton//:TritonGPUTransforms", "@triton//:TritonLLVMIR", @@ -1114,6 +1140,8 @@ cc_library( "@triton//third_party/nvidia:NVWSDialect", "@triton//third_party/nvidia:NVWSTransforms", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonIR", + "@triton//third_party/proton:ProtonGPUIR", # Shardy stuff "@shardy//shardy/dialect/sdy/ir:dialect", diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp new file mode 100644 index 0000000000..5fe2f87b36 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -0,0 +1,124 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "xla/backends/gpu/codegen/triton/compilation_pipeline.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/rocm/rocm_compute_capability.h" + +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "nvidia/include/Dialect/NVWS/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +void collectTritonKernels(SmallVectorImpl &tritonKernels, + SymbolTableCollection &symbolTable, + triton_ext::TritonCallOp op) { + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError() << "Failed to find function '" << op.getFn() << "' in " + << "module"; + return; + } + + auto wrappedMod = funcOp->getParentOfType(); + if (!wrappedMod) { + op->emitError() << "Failed to find parent built-in module."; + return; + } + + auto ttModOP = wrappedMod->getParentOfType(); + if (!ttModOP) { + op->emitError() << "No `triton_ext.module` found!"; + return; + } + + tritonKernels.push_back(wrappedMod); + return; +} + +struct LowerTritonPass + : public mlir::enzyme::impl::LowerTritonPassBase { + using Base::Base; + + void runOnOperation() override { + auto modOp = getOperation(); + + stream_executor::GpuComputeCapability gpuCC; + if (backend == "cuda") { + auto cudaCC = + stream_executor::CudaComputeCapability::FromString(computeCapability); + if (!cudaCC.ok()) { + modOp->emitError("Unsupported cuda compute capability: ") + << cudaCC.status().ToString(); + return; + } + gpuCC = stream_executor::GpuComputeCapability(cudaCC.value()); + } else if (backend == "rocm") { + auto rocmCC = stream_executor::RocmComputeCapability(computeCapability); + gpuCC = stream_executor::GpuComputeCapability(rocmCC); + } else { + modOp->emitError("Unsupported backend: ") << backend; + return; + } + + SmallVector tritonKernels; + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(modOp); + modOp->walk([&](triton_ext::TritonCallOp op) { + collectTritonKernels(tritonKernels, symbolTable, op); + }); + + OpPassManager pm; + + // TODO: bool rewrite_int4, bool allow_tma, int num_stages + xla::gpu::CreateTritonXlaPipeline(&pm, gpuCC, false, true, 1); + + mlir::triton::nvidia_gpu::ClusterInfo out_cluster_info; + // TODO: int num_warps, int num_ctas, int num_stages + xla::gpu::CreateTritonPipeline(&pm, gpuCC, 4, 1, 1, out_cluster_info); + + for (auto tritonMod : tritonKernels) { + if (failed(runPipeline(pm, tritonMod))) { + tritonMod->emitError( + "Failed to lower Triton kernel to TritonGPU kernel"); + signalPassFailure(); + return; + } + } + } +}; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 5bf0092012..52f3fec327 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1069,6 +1069,47 @@ def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass< >]; } +def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> { + let summary = "Lower Triton to kernel call"; + let dependentDialects = [ + "triton::TritonDialect", + "enzymexla::EnzymeXLADialect", + "func::FuncDialect", + "enzymexla::triton_ext::TritonExtDialect", + "stablehlo::StablehloDialect", + "triton::nvidia_gpu::TritonNvidiaGPUDialect", + "triton::instrument::TritonInstrumentDialect", + "triton::nvgpu::NVGPUDialect", + "triton::nvws::NVWSDialect", + "triton::amdgpu::TritonAMDGPUDialect", + "triton::proton::ProtonDialect", + "triton::proton::gpu::ProtonGPUDialect", + "triton::gluon::GluonDialect", + "triton::gpu::TritonGPUDialect", + "cf::ControlFlowDialect", + "math::MathDialect", + "arith::ArithDialect", + "scf::SCFDialect", + "gpu::GPUDialect", + "LLVM::LLVMDialect", + "NVVM::NVVMDialect", + "ROCDL::ROCDLDialect", + ]; + let options = [ + Option< + /*C++ variable name=*/"backend", + /*CLI argument=*/"backend", + /*type=*/"std::string", + /*default=*/"\"cuda\"", + /*description=*/"HW backend">, + Option< + /*C++ variable name=*/"computeCapability", + /*CLI argument=*/"compute_capability", + /*type=*/"std::string", + /*default=*/"\"8.0\"", + /*description=*/"Compute capability">]; +} + def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { let summary = "Legalize batching specific enzyme ops to stablehlo dialect"; let dependentDialects = [ diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 46ce854a3b..81c8f25363 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -105,14 +105,21 @@ #include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_export.h" #include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "nvidia/include/Dialect/NVWS/IR/Dialect.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "triton/Target/LLVMIR/Passes.h" @@ -218,6 +225,13 @@ void registerDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); } @@ -256,6 +270,13 @@ void loadAllRegisteredDialects(mlir::MLIRContext &context) { context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); context.loadDialect(); } diff --git a/test/lit_tests/triton/add_kernel.mlir b/test/lit_tests/triton/add_kernel.mlir index 8472a09ce2..b8c9943931 100644 --- a/test/lit_tests/triton/add_kernel.mlir +++ b/test/lit_tests/triton/add_kernel.mlir @@ -1,4 +1,5 @@ // RUN: enzymexlamlir-opt %s -canonicalize | FileCheck %s +// RUN: enzymexlamlir-opt %s -lower-triton | FileCheck %s --check-prefix=LOWER module { // CHECK: enzymexla_tt_ext.module From 7b9aebecbf86fd17de6fcedbd2626ea7736e2328 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Nov 2025 15:52:04 -0600 Subject: [PATCH 2/4] feat: lower LLVM IR to ptx --- src/enzyme_ad/jax/BUILD | 14 ++ src/enzyme_ad/jax/Passes/LowerTriton.cpp | 263 +++++++++++++++++++++-- src/enzyme_ad/jax/Passes/Passes.td | 6 + 3 files changed, 267 insertions(+), 16 deletions(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index e20878f844..2fe4f4d2d6 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -892,17 +892,23 @@ cc_library( "//src/external/isl:Isl", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@enzyme//:EnzymeMLIR", "@jax//jaxlib/gpu:triton_cc_proto", "@llvm-project//llvm:Core", "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:OrcTargetProcess", "@llvm-project//llvm:Passes", "@llvm-project//llvm:Scalar", "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -913,6 +919,7 @@ cc_library( "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:CommonFolders", @@ -922,6 +929,8 @@ cc_library( "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ControlFlowToSCF", "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:FromLLVMIRTranslation", "@llvm-project//mlir:FromLLVMIRTranslationRegistration", "@llvm-project//mlir:FuncDialect", @@ -940,6 +949,7 @@ cc_library( "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", @@ -952,7 +962,9 @@ cc_library( "@llvm-project//mlir:NVGPUDialect", "@llvm-project//mlir:NVGPUToNVVM", "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMTarget", "@llvm-project//mlir:NVVMToLLVM", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:OpenMPDialect", "@llvm-project//mlir:OpenMPToLLVM", "@llvm-project//mlir:Parser", @@ -967,6 +979,7 @@ cc_library( "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TargetLLVM", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:ToLLVMIRTranslationRegistration", @@ -1018,6 +1031,7 @@ cc_library( "@xla//xla/backends/gpu/codegen/triton:compilation_pipeline", "@xla//xla/mlir/utils:type_util", "@xla//xla/mlir_hlo", + "@xla//xla/pjrt:triton", "@xla//xla/stream_executor:device_description", "@xla//xla/stream_executor/cuda:cuda_compute_capability", "@xla//xla/stream_executor/rocm:rocm_compute_capability", diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp index 5fe2f87b36..37518f0a9b 100644 --- a/src/enzyme_ad/jax/Passes/LowerTriton.cpp +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -1,5 +1,7 @@ #include "src/enzyme_ad/jax/Passes/Passes.h" +#include + #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -14,6 +16,9 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" #include "nvidia/include/Dialect/NVWS/IR/Dialect.h" #include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" @@ -27,8 +32,32 @@ #include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "xla/pjrt/triton.h" + #include "src/enzyme_ad/jax/Utils.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_replace.h" + +#include "mlir/ExecutionEngine/OptUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/TargetParser/Triple.h" + +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #define DEBUG_TYPE "lower-triton" @@ -45,9 +74,9 @@ using namespace mlir::enzyme; using namespace mlir::enzymexla; using namespace mlir::enzymexla::triton_ext; -void collectTritonKernels(SmallVectorImpl &tritonKernels, - SymbolTableCollection &symbolTable, - triton_ext::TritonCallOp op) { +void collectTritonKernels( + DenseMap &tritonKernels, + SymbolTableCollection &symbolTable, triton_ext::TritonCallOp op) { auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); if (!funcOp) { op->emitError() << "Failed to find function '" << op.getFn() << "' in " @@ -67,10 +96,138 @@ void collectTritonKernels(SmallVectorImpl &tritonKernels, return; } - tritonKernels.push_back(wrappedMod); + tritonKernels[op] = wrappedMod; return; } +namespace cuda { + +namespace fs = std::filesystem; + +absl::StatusOr> +CreateTargetMachine(llvm::Module *module, absl::string_view arch_name, + bool enable_fp_fusion, absl::string_view features) { + // Based on createTargetMachine() in triton/python/src/llvm.cc + std::string error; + const auto *target = + llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + if (target == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to lookup LLVM target based on triple %s: %s", + module->getTargetTriple().str(), error)); + } + llvm::TargetOptions opt; + if (enable_fp_fusion) { + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + } + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + opt.MCOptions.AsmVerbose = true; + opt.MCOptions.PreserveAsmComments = true; + return std::unique_ptr(target->createTargetMachine( + module->getTargetTriple(), arch_name, features, opt, llvm::Reloc::PIC_, + std::nullopt, llvm::CodeGenOptLevel::Aggressive)); +} + +absl::Status LinkLibdevice(llvm::Module *module, std::string libdevice_dir) { + auto libdevice_path = (fs::path(libdevice_dir) / "libdevice.10.bc").string(); + + llvm::LLVMContext &ctx = module->getContext(); + llvm::SMDiagnostic err; + std::unique_ptr libdevice_module = + llvm::parseIRFile(libdevice_path, err, ctx); + if (!libdevice_module) { + return absl::InternalError( + absl::StrFormat("Failed to parse libdevice IR file at %s: %s", + libdevice_path, err.getMessage())); + } + + llvm::Linker linker(*module); + if (linker.linkInModule(std::move(libdevice_module), + llvm::Linker::Flags::LinkOnlyNeeded)) { + return absl::InternalError("Failed to link libdevice"); + } + + return absl::OkStatus(); +} + +absl::StatusOr LLVMToPTX(mlir::ModuleOp module, + absl::string_view arch_name, + std::string libdevice_dir) { + // Based on translateLLVMIRToASM() in triton/python/src/llvm.cc + mlir::DialectRegistry registry; + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + module.getContext()->appendDialectRegistry(registry); + + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + mlir::translateModuleToLLVMIR(module, llvmContext); + if (!llvmModule) { + return absl::InternalError("Failed to emit LLVM IR"); + } + + auto cc = absl::StrReplaceAll(arch_name, {{".", ""}}); // "8.0" -> "80" + auto proc = absl::StrCat("sm_", cc, cc == "90" ? "a" : ""); + // We cap the ISA at 8.4 to align with Triton. + // See get_features() in triton/third_party/nvidia/backend/compiler.py. + auto features = cc >= "84" ? "+ptx84" : "+ptx" + cc; + llvmModule->setTargetTriple(llvm::Triple("nvptx64-nvidia-cuda")); + static absl::once_flag init_target_once; + absl::call_once(init_target_once, []() { + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + }); + + auto machineOrStatus = + CreateTargetMachine(llvmModule.get(), proc, + /*enable_fp_fusion=*/false, features); + if (!machineOrStatus.ok()) { + return machineOrStatus.status(); + } + auto machine = std::move(machineOrStatus.value()); + + llvmModule->setDataLayout(machine->createDataLayout()); + + auto needsLibdevice = + llvm::any_of(llvmModule->functions(), [](const auto &f) { + return !f.isIntrinsic() && f.isDeclaration() && + f.getName().starts_with("__nv_"); + }); + if (needsLibdevice) { + auto linkStatus = LinkLibdevice(llvmModule.get(), libdevice_dir); + if (!linkStatus.ok()) { + return linkStatus; + } + } + + auto transformer = mlir::makeOptimizingTransformer( + /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/machine.get()); + if (auto error = transformer(llvmModule.get()); error) { + return absl::InternalError("Failed to optimize LLVM IR"); + } + + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream bstream(stream); + llvm::legacy::PassManager pm; + machine->addPassesToEmitFile(pm, bstream, nullptr, + llvm::CodeGenFileType::AssemblyFile, + /*DisableVerify=*/false); + if (!pm.run(*llvmModule)) { + return absl::InternalError("Failed to compile LLVM IR to PTX"); + } + } + return result; +} + +} // namespace cuda + struct LowerTritonPass : public mlir::enzyme::impl::LowerTritonPassBase { using Base::Base; @@ -96,29 +253,103 @@ struct LowerTritonPass return; } - SmallVector tritonKernels; + DenseMap tritonKernels; SymbolTableCollection symbolTable; symbolTable.getSymbolTable(modOp); modOp->walk([&](triton_ext::TritonCallOp op) { collectTritonKernels(tritonKernels, symbolTable, op); }); - OpPassManager pm; + SmallVector clusterInfos; + + OpBuilder builder(modOp); + + bool anyFailed = false; + for (auto [ttCallOp, innerMod] : tritonKernels) { + int32_t numWarps = 4; + if (innerMod->hasAttrOfType("enzymexla.num_warps")) { + numWarps = innerMod->getAttrOfType("enzymexla.num_warps") + .getInt(); + } + int32_t numCtas = 1; + if (innerMod->hasAttrOfType("enzymexla.num_ctas")) { + numCtas = + innerMod->getAttrOfType("enzymexla.num_ctas").getInt(); + } + int32_t numStages = 3; + if (innerMod->hasAttrOfType("enzymexla.num_stages")) { + numStages = innerMod->getAttrOfType("enzymexla.num_stages") + .getInt(); + } - // TODO: bool rewrite_int4, bool allow_tma, int num_stages - xla::gpu::CreateTritonXlaPipeline(&pm, gpuCC, false, true, 1); + OpPassManager pm; - mlir::triton::nvidia_gpu::ClusterInfo out_cluster_info; - // TODO: int num_warps, int num_ctas, int num_stages - xla::gpu::CreateTritonPipeline(&pm, gpuCC, 4, 1, 1, out_cluster_info); + xla::gpu::CreateTritonXlaPipeline(&pm, gpuCC, true, true, numStages); + mlir::triton::nvidia_gpu::ClusterInfo clusterInfo; + xla::gpu::CreateTritonPipeline(&pm, gpuCC, numWarps, numCtas, numStages, + clusterInfo); + clusterInfos.push_back(clusterInfo); - for (auto tritonMod : tritonKernels) { - if (failed(runPipeline(pm, tritonMod))) { - tritonMod->emitError( + if (failed(runPipeline(pm, innerMod))) { + innerMod->emitError( "Failed to lower Triton kernel to TritonGPU kernel"); - signalPassFailure(); - return; + anyFailed = true; + continue; + } + + // int32_t threadsPerWarp = 32; + // if (innerMod->hasAttrOfType("ttg.threads_per_warp")) { + // threadsPerWarp = + // innerMod->getAttrOfType("ttg.threads_per_warp") + // .getInt(); + // } + + auto ptxOrError = + cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir); + if (!ptxOrError.ok()) { + innerMod->emitError(ptxOrError.status().message()); + anyFailed = true; + continue; } + + auto ptx = ptxOrError.value(); + llvm::errs() << "Compilation result: " << ptx << "\n"; + + builder.setInsertionPoint(ttCallOp); + + // auto sharedMemSizeAttr = + // innerMod->getAttrOfType("ttg.shared"); + // auto sharedMemSize = sharedMemSizeAttr.getInt(); + // auto shmemOpType = ttCallOp.getGridx().getType(); + // auto shmemOp = stablehlo::ConstantOp::create( + // builder, ttCallOp.getLoc(), shmemOpType, + // cast(makeAttr(shmemOpType, sharedMemSize))); + + // auto blockX = stablehlo::ConstantOp::create( + // builder, ttCallOp.getLoc(), shmemOpType, + // cast(makeAttr(shmemOpType, threadsPerWarp * + // numWarps))); + // auto blockYZ = stablehlo::ConstantOp::create( + // builder, ttCallOp.getLoc(), shmemOpType, + // cast(makeAttr(shmemOpType, 1))); + + // auto kernelCallOp = enzymexla::KernelCallOp::create( + // builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(), + // ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(), + // ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp, + // ttCallOp.getClusterx(), ttCallOp.getClustery(), + // ttCallOp.getClusterz(), ttCallOp.getInputs(), + // ttCallOp.getBackendConfigAttr(), ttCallOp.getOperandLayoutsAttr(), + // ttCallOp.getResultLayoutsAttr(), ttCallOp.getArgAttrsAttr(), + // ttCallOp.getResAttrsAttr(), ttCallOp.getOutputOperandAliasesAttr(), + // ttCallOp.getXlaSideEffectFreeAttr()); + // ttCallOp.replaceAllUsesWith(kernelCallOp); + // ttCallOp.erase(); + } + + if (anyFailed) { + signalPassFailure(); + return; } } }; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 52f3fec327..ac5bd49d97 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1096,6 +1096,12 @@ def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> { "ROCDL::ROCDLDialect", ]; let options = [ + Option< + /*C++ variable name=*/"libdeviceDir", + /*CLI argument=*/"libdevice_dir", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/"Path to the libdevice directory">, Option< /*C++ variable name=*/"backend", /*CLI argument=*/"backend", From 57938ebcbdc15568d19b31789b6f146a437c4f67 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Nov 2025 17:15:33 -0600 Subject: [PATCH 3/4] feat: lower to kernel_call --- src/enzyme_ad/jax/Passes/LowerTriton.cpp | 113 ++++++++++++++--------- 1 file changed, 71 insertions(+), 42 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp index 37518f0a9b..281443e220 100644 --- a/src/enzyme_ad/jax/Passes/LowerTriton.cpp +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -297,54 +297,83 @@ struct LowerTritonPass continue; } - // int32_t threadsPerWarp = 32; - // if (innerMod->hasAttrOfType("ttg.threads_per_warp")) { - // threadsPerWarp = - // innerMod->getAttrOfType("ttg.threads_per_warp") - // .getInt(); + // remove divisibility attributes from the module before lowering to PTX + // auto funcOpInterface = dyn_cast( + // symbolTable.lookupNearestSymbolFrom(ttCallOp, + // ttCallOp.getFnAttr())); + + // if (!funcOpInterface) { + // innerMod->emitError("Failed to find function '") << ttCallOp.getFn() + // << + // "' in module"; + // anyFailed = true; + // continue; // } - auto ptxOrError = - cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir); - if (!ptxOrError.ok()) { - innerMod->emitError(ptxOrError.status().message()); - anyFailed = true; - continue; - } + // mlir::StringAttr divAttrName = + // builder.getStringAttr("tt.divisibility"); for (size_t i = 0; i < + // ttCallOp.getInputs().size(); ++i) { + // funcOpInterface.removeArgAttr(i, divAttrName); + // } - auto ptx = ptxOrError.value(); - llvm::errs() << "Compilation result: " << ptx << "\n"; + // auto ptxOrError = + // cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir); + // if (!ptxOrError.ok()) { + // innerMod->emitError(ptxOrError.status().message()); + // anyFailed = true; + // continue; + // } + + // auto ptx = ptxOrError.value(); + // llvm::errs() << "Compilation result: " << ptx << "\n"; + + int32_t threadsPerWarp = 32; + if (innerMod->hasAttrOfType("ttg.threads_per_warp")) { + threadsPerWarp = + innerMod->getAttrOfType("ttg.threads_per_warp") + .getInt(); + } builder.setInsertionPoint(ttCallOp); - // auto sharedMemSizeAttr = - // innerMod->getAttrOfType("ttg.shared"); - // auto sharedMemSize = sharedMemSizeAttr.getInt(); - // auto shmemOpType = ttCallOp.getGridx().getType(); - // auto shmemOp = stablehlo::ConstantOp::create( - // builder, ttCallOp.getLoc(), shmemOpType, - // cast(makeAttr(shmemOpType, sharedMemSize))); - - // auto blockX = stablehlo::ConstantOp::create( - // builder, ttCallOp.getLoc(), shmemOpType, - // cast(makeAttr(shmemOpType, threadsPerWarp * - // numWarps))); - // auto blockYZ = stablehlo::ConstantOp::create( - // builder, ttCallOp.getLoc(), shmemOpType, - // cast(makeAttr(shmemOpType, 1))); - - // auto kernelCallOp = enzymexla::KernelCallOp::create( - // builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(), - // ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(), - // ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp, - // ttCallOp.getClusterx(), ttCallOp.getClustery(), - // ttCallOp.getClusterz(), ttCallOp.getInputs(), - // ttCallOp.getBackendConfigAttr(), ttCallOp.getOperandLayoutsAttr(), - // ttCallOp.getResultLayoutsAttr(), ttCallOp.getArgAttrsAttr(), - // ttCallOp.getResAttrsAttr(), ttCallOp.getOutputOperandAliasesAttr(), - // ttCallOp.getXlaSideEffectFreeAttr()); - // ttCallOp.replaceAllUsesWith(kernelCallOp); - // ttCallOp.erase(); + auto sharedMemSizeAttr = + innerMod->getAttrOfType("ttg.shared"); + auto sharedMemSize = sharedMemSizeAttr.getInt(); + auto shmemOpType = ttCallOp.getGridx().getType(); + auto shmemOp = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, sharedMemSize))); + + auto blockX = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, threadsPerWarp * numWarps))); + auto blockYZ = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, 1))); + + SmallVector newInputs(ttCallOp.getInputs().begin(), + ttCallOp.getInputs().end()); + // we don't use the next 2 inputs + auto scratchSpace = stablehlo::ConstantOp::create( + builder, ttCallOp.getLoc(), + RankedTensorType::get({}, builder.getI8Type()), + cast( + makeAttr(RankedTensorType::get({}, builder.getI8Type()), 0))); + newInputs.push_back(scratchSpace); + newInputs.push_back(scratchSpace); + + auto kernelCallOp = enzymexla::KernelCallOp::create( + builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(), + ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(), + ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp, + ttCallOp.getClusterx(), ttCallOp.getClustery(), + ttCallOp.getClusterz(), newInputs, ttCallOp.getBackendConfigAttr(), + ttCallOp.getOperandLayoutsAttr(), ttCallOp.getResultLayoutsAttr(), + ttCallOp.getArgAttrsAttr(), ttCallOp.getResAttrsAttr(), + ttCallOp.getOutputOperandAliasesAttr(), + ttCallOp.getXlaSideEffectFreeAttr()); + ttCallOp.replaceAllUsesWith(kernelCallOp); + ttCallOp.erase(); } if (anyFailed) { From e873f6d034c8a92cc51ad6b4beac3db9ff077122 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 27 Nov 2025 07:39:26 -0600 Subject: [PATCH 4/4] chore: cleanup --- src/enzyme_ad/jax/BUILD | 6 - src/enzyme_ad/jax/Passes/LowerTriton.cpp | 182 ----------------------- src/enzyme_ad/jax/Passes/Passes.td | 6 - 3 files changed, 194 deletions(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 2fe4f4d2d6..5f175c0c62 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -929,8 +929,6 @@ cc_library( "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ControlFlowToSCF", "@llvm-project//mlir:DLTIDialect", - "@llvm-project//mlir:ExecutionEngine", - "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:FromLLVMIRTranslation", "@llvm-project//mlir:FromLLVMIRTranslationRegistration", "@llvm-project//mlir:FuncDialect", @@ -949,7 +947,6 @@ cc_library( "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", @@ -962,9 +959,7 @@ cc_library( "@llvm-project//mlir:NVGPUDialect", "@llvm-project//mlir:NVGPUToNVVM", "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:NVVMTarget", "@llvm-project//mlir:NVVMToLLVM", - "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:OpenMPDialect", "@llvm-project//mlir:OpenMPToLLVM", "@llvm-project//mlir:Parser", @@ -979,7 +974,6 @@ cc_library( "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TargetLLVM", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:ToLLVMIRTranslationRegistration", diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp index 281443e220..8a7664a685 100644 --- a/src/enzyme_ad/jax/Passes/LowerTriton.cpp +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -16,9 +16,6 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" #include "nvidia/include/Dialect/NVWS/IR/Dialect.h" #include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" @@ -36,27 +33,6 @@ #include "src/enzyme_ad/jax/Utils.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_replace.h" - -#include "mlir/ExecutionEngine/OptUtils.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Module.h" -#include "llvm/IRReader/IRReader.h" -#include "llvm/Linker/Linker.h" -#include "llvm/MC/TargetRegistry.h" -#include "llvm/Support/CodeGen.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/TargetParser/Triple.h" - #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" @@ -100,134 +76,6 @@ void collectTritonKernels( return; } -namespace cuda { - -namespace fs = std::filesystem; - -absl::StatusOr> -CreateTargetMachine(llvm::Module *module, absl::string_view arch_name, - bool enable_fp_fusion, absl::string_view features) { - // Based on createTargetMachine() in triton/python/src/llvm.cc - std::string error; - const auto *target = - llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); - if (target == nullptr) { - return absl::InternalError( - absl::StrFormat("Failed to lookup LLVM target based on triple %s: %s", - module->getTargetTriple().str(), error)); - } - llvm::TargetOptions opt; - if (enable_fp_fusion) { - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - } - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - opt.TrapUnreachable = true; - opt.MCOptions.AsmVerbose = true; - opt.MCOptions.PreserveAsmComments = true; - return std::unique_ptr(target->createTargetMachine( - module->getTargetTriple(), arch_name, features, opt, llvm::Reloc::PIC_, - std::nullopt, llvm::CodeGenOptLevel::Aggressive)); -} - -absl::Status LinkLibdevice(llvm::Module *module, std::string libdevice_dir) { - auto libdevice_path = (fs::path(libdevice_dir) / "libdevice.10.bc").string(); - - llvm::LLVMContext &ctx = module->getContext(); - llvm::SMDiagnostic err; - std::unique_ptr libdevice_module = - llvm::parseIRFile(libdevice_path, err, ctx); - if (!libdevice_module) { - return absl::InternalError( - absl::StrFormat("Failed to parse libdevice IR file at %s: %s", - libdevice_path, err.getMessage())); - } - - llvm::Linker linker(*module); - if (linker.linkInModule(std::move(libdevice_module), - llvm::Linker::Flags::LinkOnlyNeeded)) { - return absl::InternalError("Failed to link libdevice"); - } - - return absl::OkStatus(); -} - -absl::StatusOr LLVMToPTX(mlir::ModuleOp module, - absl::string_view arch_name, - std::string libdevice_dir) { - // Based on translateLLVMIRToASM() in triton/python/src/llvm.cc - mlir::DialectRegistry registry; - mlir::registerBuiltinDialectTranslation(registry); - mlir::registerLLVMDialectTranslation(registry); - mlir::registerNVVMDialectTranslation(registry); - module.getContext()->appendDialectRegistry(registry); - - llvm::LLVMContext llvmContext; - std::unique_ptr llvmModule = - mlir::translateModuleToLLVMIR(module, llvmContext); - if (!llvmModule) { - return absl::InternalError("Failed to emit LLVM IR"); - } - - auto cc = absl::StrReplaceAll(arch_name, {{".", ""}}); // "8.0" -> "80" - auto proc = absl::StrCat("sm_", cc, cc == "90" ? "a" : ""); - // We cap the ISA at 8.4 to align with Triton. - // See get_features() in triton/third_party/nvidia/backend/compiler.py. - auto features = cc >= "84" ? "+ptx84" : "+ptx" + cc; - llvmModule->setTargetTriple(llvm::Triple("nvptx64-nvidia-cuda")); - static absl::once_flag init_target_once; - absl::call_once(init_target_once, []() { - LLVMInitializeNVPTXTarget(); - LLVMInitializeNVPTXTargetInfo(); - LLVMInitializeNVPTXTargetMC(); - LLVMInitializeNVPTXAsmPrinter(); - }); - - auto machineOrStatus = - CreateTargetMachine(llvmModule.get(), proc, - /*enable_fp_fusion=*/false, features); - if (!machineOrStatus.ok()) { - return machineOrStatus.status(); - } - auto machine = std::move(machineOrStatus.value()); - - llvmModule->setDataLayout(machine->createDataLayout()); - - auto needsLibdevice = - llvm::any_of(llvmModule->functions(), [](const auto &f) { - return !f.isIntrinsic() && f.isDeclaration() && - f.getName().starts_with("__nv_"); - }); - if (needsLibdevice) { - auto linkStatus = LinkLibdevice(llvmModule.get(), libdevice_dir); - if (!linkStatus.ok()) { - return linkStatus; - } - } - - auto transformer = mlir::makeOptimizingTransformer( - /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/machine.get()); - if (auto error = transformer(llvmModule.get()); error) { - return absl::InternalError("Failed to optimize LLVM IR"); - } - - std::string result; - { - llvm::raw_string_ostream stream(result); - llvm::buffer_ostream bstream(stream); - llvm::legacy::PassManager pm; - machine->addPassesToEmitFile(pm, bstream, nullptr, - llvm::CodeGenFileType::AssemblyFile, - /*DisableVerify=*/false); - if (!pm.run(*llvmModule)) { - return absl::InternalError("Failed to compile LLVM IR to PTX"); - } - } - return result; -} - -} // namespace cuda - struct LowerTritonPass : public mlir::enzyme::impl::LowerTritonPassBase { using Base::Base; @@ -297,36 +145,6 @@ struct LowerTritonPass continue; } - // remove divisibility attributes from the module before lowering to PTX - // auto funcOpInterface = dyn_cast( - // symbolTable.lookupNearestSymbolFrom(ttCallOp, - // ttCallOp.getFnAttr())); - - // if (!funcOpInterface) { - // innerMod->emitError("Failed to find function '") << ttCallOp.getFn() - // << - // "' in module"; - // anyFailed = true; - // continue; - // } - - // mlir::StringAttr divAttrName = - // builder.getStringAttr("tt.divisibility"); for (size_t i = 0; i < - // ttCallOp.getInputs().size(); ++i) { - // funcOpInterface.removeArgAttr(i, divAttrName); - // } - - // auto ptxOrError = - // cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir); - // if (!ptxOrError.ok()) { - // innerMod->emitError(ptxOrError.status().message()); - // anyFailed = true; - // continue; - // } - - // auto ptx = ptxOrError.value(); - // llvm::errs() << "Compilation result: " << ptx << "\n"; - int32_t threadsPerWarp = 32; if (innerMod->hasAttrOfType("ttg.threads_per_warp")) { threadsPerWarp = diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index ac5bd49d97..52f3fec327 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1096,12 +1096,6 @@ def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> { "ROCDL::ROCDLDialect", ]; let options = [ - Option< - /*C++ variable name=*/"libdeviceDir", - /*CLI argument=*/"libdevice_dir", - /*type=*/"std::string", - /*default=*/"\"\"", - /*description=*/"Path to the libdevice directory">, Option< /*C++ variable name=*/"backend", /*CLI argument=*/"backend",