1 change: 1 addition & 0 deletions .bazelrc
@@ -24,6 +24,7 @@ common
common --define framework_shared_object=true
common --define tsl_protobuf_header_only=true
common --define=allow_oversize_protos=true
common --check_visibility=false


# Some targets have the same py source file, but use different
36 changes: 36 additions & 0 deletions src/enzyme_ad/jax/BUILD
@@ -892,17 +892,23 @@ cc_library(
"//src/external/isl:Isl",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@enzyme//:EnzymeMLIR",
"@jax//jaxlib/gpu:triton_cc_proto",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:InstCombine",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcTargetProcess",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Scalar",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:TargetParser",
"@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineToStandard",
@@ -913,6 +919,7 @@ cc_library(
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithUtils",
"@llvm-project//mlir:AsyncDialect",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:BytecodeOpInterface",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:CommonFolders",
@@ -992,10 +999,36 @@ cc_library(
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@stablehlo//:stablehlo_type_inference",
"@triton//:GluonDialect",
"@triton//:TritonDialects",
"@triton//:TritonGPUToLLVM",
"@triton//:TritonGPUTransforms",
"@triton//:TritonLLVMIR",
"@triton//:TritonNvidiaGPUTransforms",
"@triton//:TritonToTritonGPU",
"@triton//:TritonToTritonGPUPasses",
"@triton//:TritonTransforms",
"@triton//:WarpSpecialization",
"@triton//:triton_conversion_triton_to_triton_gpu_passes_inc_gen",
"@triton//:triton_nvidia_gpu_transforms_inc_gen",
"@triton//third_party/amd:TritonAMDGPU",
"@triton//third_party/amd:TritonAMDGPUToLLVM",
"@triton//third_party/amd:TritonAMDGPUTransforms",
"@triton//third_party/nvidia:NVGPUDialect",
"@triton//third_party/nvidia:NVGPUToLLVM",
"@triton//third_party/nvidia:NVHopperTransforms",
"@triton//third_party/nvidia:NVWSDialect",
"@triton//third_party/nvidia:NVWSTransforms",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//third_party/proton:ProtonGPUIR",
"@triton//third_party/proton:ProtonIR",
"@xla//xla/backends/gpu/codegen/triton:compilation_pipeline",
"@xla//xla/mlir/utils:type_util",
"@xla//xla/mlir_hlo",
"@xla//xla/pjrt:triton",
"@xla//xla/stream_executor:device_description",
"@xla//xla/stream_executor/cuda:cuda_compute_capability",
"@xla//xla/stream_executor/rocm:rocm_compute_capability",
"@zlib",
],
)
@@ -1095,6 +1128,7 @@ cc_library(

# Triton
"@triton//:TritonDialects",
"@triton//:GluonDialect",
"@triton//:TritonGPUToLLVM",
"@triton//:TritonGPUTransforms",
"@triton//:TritonLLVMIR",
@@ -1114,6 +1148,8 @@ cc_library(
"@triton//third_party/nvidia:NVWSDialect",
"@triton//third_party/nvidia:NVWSTransforms",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//third_party/proton:ProtonIR",
"@triton//third_party/proton:ProtonGPUIR",

# Shardy stuff
"@shardy//shardy/dialect/sdy/ir:dialect",
202 changes: 202 additions & 0 deletions src/enzyme_ad/jax/Passes/LowerTriton.cpp
@@ -0,0 +1,202 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include <filesystem>

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "xla/backends/gpu/codegen/triton/compilation_pipeline.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/rocm/rocm_compute_capability.h"

#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Ops.h"
#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
#include "triton/Dialect/Gluon/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

#include "xla/pjrt/triton.h"

#include "src/enzyme_ad/jax/Utils.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

#define DEBUG_TYPE "lower-triton"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_LOWERTRITONPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::enzymexla;
using namespace mlir::enzymexla::triton_ext;

void collectTritonKernels(
DenseMap<triton_ext::TritonCallOp, ModuleOp> &tritonKernels,
SymbolTableCollection &symbolTable, triton_ext::TritonCallOp op) {
auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
if (!funcOp) {
op->emitError() << "Failed to find function '" << op.getFn() << "' in "
<< "module";
return;
}

auto wrappedMod = funcOp->getParentOfType<ModuleOp>();
if (!wrappedMod) {
op->emitError() << "Failed to find parent built-in module.";
return;
}

auto ttModOP = wrappedMod->getParentOfType<triton_ext::TritonModuleOp>();
if (!ttModOP) {
op->emitError() << "No `triton_ext.module` found!";
return;
}

tritonKernels[op] = wrappedMod;
return;
}

struct LowerTritonPass
: public mlir::enzyme::impl::LowerTritonPassBase<LowerTritonPass> {
using Base::Base;

void runOnOperation() override {
auto modOp = getOperation();

stream_executor::GpuComputeCapability gpuCC;
if (backend == "cuda") {
auto cudaCC =
stream_executor::CudaComputeCapability::FromString(computeCapability);
if (!cudaCC.ok()) {
modOp->emitError("Unsupported cuda compute capability: ")
<< cudaCC.status().ToString();
return;
}
gpuCC = stream_executor::GpuComputeCapability(cudaCC.value());
} else if (backend == "rocm") {
auto rocmCC = stream_executor::RocmComputeCapability(computeCapability);
gpuCC = stream_executor::GpuComputeCapability(rocmCC);
} else {
modOp->emitError("Unsupported backend: ") << backend;
return;
}

DenseMap<triton_ext::TritonCallOp, ModuleOp> tritonKernels;
SymbolTableCollection symbolTable;
symbolTable.getSymbolTable(modOp);
modOp->walk([&](triton_ext::TritonCallOp op) {
collectTritonKernels(tritonKernels, symbolTable, op);
});

SmallVector<mlir::triton::nvidia_gpu::ClusterInfo> clusterInfos;

OpBuilder builder(modOp);

bool anyFailed = false;
for (auto [ttCallOp, innerMod] : tritonKernels) {
int32_t numWarps = 4;
if (innerMod->hasAttrOfType<IntegerAttr>("enzymexla.num_warps")) {
numWarps = innerMod->getAttrOfType<IntegerAttr>("enzymexla.num_warps")
.getInt();
}
int32_t numCtas = 1;
if (innerMod->hasAttrOfType<IntegerAttr>("enzymexla.num_ctas")) {
numCtas =
innerMod->getAttrOfType<IntegerAttr>("enzymexla.num_ctas").getInt();
}
int32_t numStages = 3;
if (innerMod->hasAttrOfType<IntegerAttr>("enzymexla.num_stages")) {
numStages = innerMod->getAttrOfType<IntegerAttr>("enzymexla.num_stages")
.getInt();
}

OpPassManager pm;

xla::gpu::CreateTritonXlaPipeline(&pm, gpuCC, true, true, numStages);
mlir::triton::nvidia_gpu::ClusterInfo clusterInfo;
xla::gpu::CreateTritonPipeline(&pm, gpuCC, numWarps, numCtas, numStages,
clusterInfo);
clusterInfos.push_back(clusterInfo);

if (failed(runPipeline(pm, innerMod))) {
innerMod->emitError(
"Failed to lower Triton kernel to TritonGPU kernel");
anyFailed = true;
continue;
}

int32_t threadsPerWarp = 32;
if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) {
threadsPerWarp =
innerMod->getAttrOfType<IntegerAttr>("ttg.threads_per_warp")
.getInt();
}

builder.setInsertionPoint(ttCallOp);

auto sharedMemSizeAttr =
innerMod->getAttrOfType<IntegerAttr>("ttg.shared");
auto sharedMemSize = sharedMemSizeAttr.getInt();
auto shmemOpType = ttCallOp.getGridx().getType();
auto shmemOp = stablehlo::ConstantOp::create(
builder, ttCallOp.getLoc(), shmemOpType,
cast<ElementsAttr>(makeAttr(shmemOpType, sharedMemSize)));

auto blockX = stablehlo::ConstantOp::create(
builder, ttCallOp.getLoc(), shmemOpType,
cast<ElementsAttr>(makeAttr(shmemOpType, threadsPerWarp * numWarps)));
auto blockYZ = stablehlo::ConstantOp::create(
builder, ttCallOp.getLoc(), shmemOpType,
cast<ElementsAttr>(makeAttr(shmemOpType, 1)));

SmallVector<mlir::Value> newInputs(ttCallOp.getInputs().begin(),
ttCallOp.getInputs().end());
// The next two kernel-call inputs are unused; pass zero-valued i8 scalar placeholders.
auto scratchSpace = stablehlo::ConstantOp::create(
builder, ttCallOp.getLoc(),
RankedTensorType::get({}, builder.getI8Type()),
cast<ElementsAttr>(
makeAttr(RankedTensorType::get({}, builder.getI8Type()), 0)));
newInputs.push_back(scratchSpace);
newInputs.push_back(scratchSpace);

auto kernelCallOp = enzymexla::KernelCallOp::create(
builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(),
ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(),
ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp,
ttCallOp.getClusterx(), ttCallOp.getClustery(),
ttCallOp.getClusterz(), newInputs, ttCallOp.getBackendConfigAttr(),
ttCallOp.getOperandLayoutsAttr(), ttCallOp.getResultLayoutsAttr(),
ttCallOp.getArgAttrsAttr(), ttCallOp.getResAttrsAttr(),
ttCallOp.getOutputOperandAliasesAttr(),
ttCallOp.getXlaSideEffectFreeAttr());
ttCallOp.replaceAllUsesWith(kernelCallOp);
ttCallOp.erase();
}

if (anyFailed) {
signalPassFailure();
return;
}
}
};
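Reviewer note: a minimal sketch of how this pass could be added to a pipeline programmatically, assuming the usual GEN_PASS_DECL-generated constructor and options struct (`createLowerTritonPass`, `LowerTritonPassOptions`) are emitted into Passes.h.inc; none of this is part of the diff itself.

```cpp
// Hypothetical usage sketch; names follow standard MLIR tablegen pass
// conventions and the options declared in Passes.td, not code in this PR.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

static mlir::LogicalResult lowerTritonKernels(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());

  // Option fields mirror the TableGen declaration below:
  // backend = "cuda" | "rocm", compute capability as a string such as "8.0".
  mlir::enzyme::LowerTritonPassOptions options;
  options.backend = "cuda";
  options.computeCapability = "8.0";
  pm.addPass(mlir::enzyme::createLowerTritonPass(options));

  return pm.run(module);
}
```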
41 changes: 41 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
@@ -1069,6 +1069,47 @@ def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass<
>];
}

def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> {
let summary = "Lower Triton to kernel call";
let dependentDialects = [
"triton::TritonDialect",
"enzymexla::EnzymeXLADialect",
"func::FuncDialect",
"enzymexla::triton_ext::TritonExtDialect",
"stablehlo::StablehloDialect",
"triton::nvidia_gpu::TritonNvidiaGPUDialect",
"triton::instrument::TritonInstrumentDialect",
"triton::nvgpu::NVGPUDialect",
"triton::nvws::NVWSDialect",
"triton::amdgpu::TritonAMDGPUDialect",
"triton::proton::ProtonDialect",
"triton::proton::gpu::ProtonGPUDialect",
"triton::gluon::GluonDialect",
"triton::gpu::TritonGPUDialect",
"cf::ControlFlowDialect",
"math::MathDialect",
"arith::ArithDialect",
"scf::SCFDialect",
"gpu::GPUDialect",
"LLVM::LLVMDialect",
"NVVM::NVVMDialect",
"ROCDL::ROCDLDialect",
];
let options = [
Option<
/*C++ variable name=*/"backend",
/*CLI argument=*/"backend",
/*type=*/"std::string",
/*default=*/"\"cuda\"",
/*description=*/"HW backend">,
Option<
/*C++ variable name=*/"computeCapability",
/*CLI argument=*/"compute_capability",
/*type=*/"std::string",
/*default=*/"\"8.0\"",
/*description=*/"Compute capability">];
}

def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
let summary = "Legalize batching specific enzyme ops to stablehlo dialect";
let dependentDialects = [
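Given the CLI argument and option names declared for LowerTritonPass above, the corresponding textual pipeline form should be parseable as sketched below; this relies only on standard MLIR pipeline parsing, assumes the pass has already been registered, and is not exercised anywhere in this diff.

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Sketch: parse the textual form of the new pass. Option names match the
// TableGen declarations above ("backend", "compute_capability"); the pass
// must already be registered for parsing to succeed.
mlir::FailureOr<mlir::OpPassManager> parseLowerTritonPipeline() {
  return mlir::parsePassPipeline(
      "builtin.module(lower-triton{backend=cuda compute_capability=8.0})");
}
```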
21 changes: 21 additions & 0 deletions src/enzyme_ad/jax/RegistryUtils.cpp
@@ -105,14 +105,21 @@
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_export.h"
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"

#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
#include "nvidia/include/NVGPUToLLVM/Passes.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/Gluon/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/Passes.h"
@@ -218,6 +225,13 @@ void registerDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::ub::UBDialect>();
registry.insert<mlir::triton::TritonDialect>();
registry.insert<mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect>();
registry.insert<mlir::triton::instrument::TritonInstrumentDialect>();
registry.insert<mlir::triton::nvgpu::NVGPUDialect>();
registry.insert<mlir::triton::nvws::NVWSDialect>();
registry.insert<mlir::triton::amdgpu::TritonAMDGPUDialect>();
registry.insert<mlir::triton::proton::ProtonDialect>();
registry.insert<mlir::triton::proton::gpu::ProtonGPUDialect>();
registry.insert<mlir::triton::gluon::GluonDialect>();
registry.insert<mlir::triton::gpu::TritonGPUDialect>();
}

@@ -256,6 +270,13 @@ void loadAllRegisteredDialects(mlir::MLIRContext &context) {
context.loadDialect<mlir::ub::UBDialect>();
context.loadDialect<mlir::triton::TritonDialect>();
context.loadDialect<mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect>();
context.loadDialect<mlir::triton::instrument::TritonInstrumentDialect>();
context.loadDialect<mlir::triton::nvgpu::NVGPUDialect>();
context.loadDialect<mlir::triton::nvws::NVWSDialect>();
context.loadDialect<mlir::triton::amdgpu::TritonAMDGPUDialect>();
context.loadDialect<mlir::triton::proton::ProtonDialect>();
context.loadDialect<mlir::triton::proton::gpu::ProtonGPUDialect>();
context.loadDialect<mlir::triton::gluon::GluonDialect>();
context.loadDialect<mlir::triton::gpu::TritonGPUDialect>();
}

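For context on how the expanded registry is consumed, a minimal sketch follows; only the two function names come from this file, while their namespace, declarations, and the surrounding tool setup are assumptions for illustration.

```cpp
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

// Declarations matching RegistryUtils.cpp (namespace assumed global here;
// adjust to the project's actual header).
void registerDialects(mlir::DialectRegistry &registry);
void loadAllRegisteredDialects(mlir::MLIRContext &context);

// Illustrative only: wire the updated registry into a context so the newly
// added Triton-adjacent dialects (NVWS, Proton, Gluon, ...) can be parsed.
void setUpContext(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  registerDialects(registry);         // now also inserts the dialects added above
  context.appendDialectRegistry(registry);
  loadAllRegisteredDialects(context); // eagerly loads them, matching the diff
}
```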