134 changes: 133 additions & 1 deletion src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
@@ -27,6 +27,50 @@ using namespace mlir::enzyme;
using namespace mlir::enzymexla;
using namespace mlir::stablehlo;

// Helper function to extract constant scalar value (real/imag parts)
static bool extractConstantScalar(Value val, double &realPart,
double &imagPart) {
DenseElementsAttr attr;
if (!matchPattern(val, m_Constant(&attr)))
return false;

auto valType = cast<RankedTensorType>(val.getType());
auto elemType = valType.getElementType();

if (auto complexType = dyn_cast<ComplexType>(elemType)) {
// Complex scalar
auto complexVal = attr.getSplatValue<std::complex<APFloat>>();
realPart = complexVal.real().convertToDouble();
imagPart = complexVal.imag().convertToDouble();
return true;
} else if (isa<FloatType>(elemType)) {
// Real scalar
realPart = attr.getSplatValue<APFloat>().convertToDouble();
imagPart = 0.0;
return true;
}
return false;
}
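// Worked example (hypothetical, for illustration only): for
//   %cst = stablehlo.constant dense<2.0> : tensor<f32>
// extractConstantScalar(%cst, re, im) returns true with re = 2.0 and
// im = 0.0; for a non-constant SSA value it returns false and leaves
// re/im unmodified.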

// Helper that returns the operand to pass to the custom call together with
// its rank: a 0-element placeholder tensor (rank 1) when the scalar is baked
// into the backend_config, or the original scalar tensor (rank 0) otherwise.
static std::pair<Value, int64_t> createScalarOperand(PatternRewriter &rewriter,
Location loc,
Value originalVal,
bool useAttribute) {
if (useAttribute) {
// Create an empty 0-element tensor as placeholder
auto emptyType = RankedTensorType::get(
{0}, cast<RankedTensorType>(originalVal.getType()).getElementType());
auto emptyTensor = stablehlo::ConstantOp::create(
rewriter, loc,
DenseElementsAttr::get(emptyType, ArrayRef<Attribute>{}));
return {emptyTensor, 1};
}
return {originalVal, 0};
}
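// Worked example (values hypothetical): with useAttribute = true this
// returns a 0-element tensor<0xf32> placeholder and rank 1 -- these are the
// `tensor.empty() : tensor<0xf32>` operands visible in the CUDA tests below
// -- while useAttribute = false forwards the original tensor<f32> scalar
// with rank 0.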

struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
using OpRewritePattern<enzymexla::SyrkOp>::OpRewritePattern;

@@ -39,6 +83,8 @@ struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
PatternRewriter &rewriter) const override {
if (backend == "cpu")
return matchAndRewriteCPU(op, rewriter);
if (backend == "cuda")
return matchAndRewriteCUDA(op, rewriter);

return matchAndRewriteFallback(op, rewriter);
}
@@ -248,7 +294,93 @@ struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
return success();
}

LogicalResult matchAndRewriteCUDA(enzymexla::SyrkOp op,
PatternRewriter &rewriter) const {
auto CType = cast<RankedTensorType>(op.getC().getType());
auto rank = CType.getRank();

bool isComplex = isa<ComplexType>(CType.getElementType());

if (isComplex && op.getTranspose() == enzymexla::LapackTranspose::adjoint) {
return rewriter.notifyMatchFailure(
op, "adjoint (conjugate transpose) is not supported for complex matrices");
}
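// Context (our reading of BLAS semantics, not stated in the diff): complex
// xSYRK only accepts op(A) = N or T; the conjugate-transpose update for
// complex matrices belongs to xHERK, so bail out rather than emit an
// invalid cuBLAS call.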

bool transpose = op.getTranspose() != enzymexla::LapackTranspose::none;

// Try to extract alpha and beta as constants
double alphaReal = 0.0, alphaImag = 0.0;
double betaReal = 0.0, betaImag = 0.0;
bool useAlphaAttr =
extractConstantScalar(op.getAlpha(), alphaReal, alphaImag);
bool useBetaAttr = extractConstantScalar(op.getBeta(), betaReal, betaImag);

// Build operands list - use empty tensors for constant alpha/beta
SmallVector<Value> operands;
operands.push_back(op.getA());
operands.push_back(op.getC());

SmallVector<int64_t> operandRanks = {rank, rank};

auto [alphaOperand, alphaRank] =
createScalarOperand(rewriter, op.getLoc(), op.getAlpha(), useAlphaAttr);
operands.push_back(alphaOperand);
operandRanks.push_back(alphaRank);

auto [betaOperand, betaRank] =
createScalarOperand(rewriter, op.getLoc(), op.getBeta(), useBetaAttr);
operands.push_back(betaOperand);
operandRanks.push_back(betaRank);

auto customCall = stablehlo::CustomCallOp::create(
rewriter, op.getLoc(), TypeRange{CType}, operands,
rewriter.getStringAttr("reactant_cublas_syrk_ffi"),
/*has_side_effect*/ nullptr,
/*backend_config*/
rewriter.getDictionaryAttr({
rewriter.getNamedAttr("transpose", rewriter.getBoolAttr(transpose)),
rewriter.getNamedAttr(
"uplo",
rewriter.getBoolAttr(op.getUplo() == enzymexla::LapackUplo::U)),
rewriter.getNamedAttr("use_alpha_attribute",
rewriter.getBoolAttr(useAlphaAttr)),
rewriter.getNamedAttr("use_beta_attribute",
rewriter.getBoolAttr(useBetaAttr)),
rewriter.getNamedAttr("alpha_real",
rewriter.getF64FloatAttr(alphaReal)),
rewriter.getNamedAttr("alpha_imag",
rewriter.getF64FloatAttr(alphaImag)),
rewriter.getNamedAttr("beta_real",
rewriter.getF64FloatAttr(betaReal)),
rewriter.getNamedAttr("beta_imag",
rewriter.getF64FloatAttr(betaImag)),
}),
/*api_version*/
stablehlo::CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
/*calledcomputations*/ nullptr,
/*operand_layouts*/
getSHLOLayout(rewriter, operandRanks, SmallVector<bool>(4, true), rank),
/*result_layouts*/
getSHLOLayout(rewriter, {rank}, SmallVector<bool>(rank, true), rank),
/*output_operand_aliases*/
rewriter.getArrayAttr({
stablehlo::OutputOperandAliasAttr::get(op.getContext(), {}, 1, {}),
}));

auto result = customCall.getResult(0);
if (op.getFill() || op.getUplo() == enzymexla::LapackUplo::L) {
result = stablehlo::copyTriangularPart(rewriter, result,
enzymexla::LapackUplo::U);
}
rewriter.replaceAllUsesWith(op.getResult(), result);

return success();
}

LogicalResult matchAndRewriteFallback(enzymexla::SyrkOp op,
PatternRewriter &rewriter) const {
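For reference, the lowered operation follows standard BLAS SYRK semantics
(not restated in the diff):

    C := alpha * op(A) * op(A)^T + beta * C,   with op(A) = A or A^T

cuBLAS writes only the `uplo` triangle of C, which is why the CUDA lowerings
below mirror one triangle into the other (the iota/compare/select sequence
produced by copyTriangularPart) after the custom call.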
63 changes: 63 additions & 0 deletions test/lit_tests/linalg/syrk.mlir
@@ -1,4 +1,5 @@
// RUN: enzymexlamlir-opt --lower-enzymexla-blas="backend=cpu" --enzyme-hlo-opt %s | FileCheck %s --check-prefix=CPU
// RUN: enzymexlamlir-opt --lower-enzymexla-blas="backend=cuda" --enzyme-hlo-opt %s | FileCheck %s --check-prefix=CUDA
// RUN: enzymexlamlir-opt --lower-enzymexla-blas="backend=tpu" --enzyme-hlo-opt %s | FileCheck %s --check-prefix=TPU

module {
@@ -36,6 +37,18 @@ module {
// CPU-NEXT: return %5 : tensor<64x64xf32>
// CPU-NEXT: }

// CUDA: func.func @main1(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
// CUDA-NEXT: %0 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %1 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %2 = stablehlo.custom_call @reactant_cublas_syrk_ffi(%arg0, %arg1, %0, %1) {api_version = 4 : i32, backend_config = {alpha_imag = 0.000000e+00 : f64, alpha_real = 2.000000e+00 : f64, beta_imag = 0.000000e+00 : f64, beta_real = 3.000000e+00 : f64, transpose = false, uplo = true, use_alpha_attribute = true, use_beta_attribute = true}, enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>], operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<0xf32>, tensor<0xf32>) -> tensor<64x64xf32>
// CUDA-NEXT: %3 = stablehlo.iota dim = 0 : tensor<64x64xi32>
// CUDA-NEXT: %4 = stablehlo.iota dim = 1 : tensor<64x64xi32>
// CUDA-NEXT: %5 = stablehlo.compare LT, %3, %4 : (tensor<64x64xi32>, tensor<64x64xi32>) -> tensor<64x64xi1>
// CUDA-NEXT: %6 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
// CUDA-NEXT: %7 = stablehlo.select %5, %2, %6 : tensor<64x64xi1>, tensor<64x64xf32>
// CUDA-NEXT: return %7 : tensor<64x64xf32>
// CUDA-NEXT: }

// TPU: func.func @main1(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
// TPU-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<64x64xf32>
// TPU-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<64x64xf32>
@@ -81,6 +94,18 @@ module {
// CPU-NEXT: return %0 : tensor<64x64xf32>
// CPU-NEXT: }

// CUDA: func.func @main2(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
// CUDA-NEXT: %0 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %1 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %2 = stablehlo.custom_call @reactant_cublas_syrk_ffi(%arg0, %arg1, %0, %1) {api_version = 4 : i32, backend_config = {alpha_imag = 0.000000e+00 : f64, alpha_real = 2.000000e+00 : f64, beta_imag = 0.000000e+00 : f64, beta_real = 3.000000e+00 : f64, transpose = false, uplo = false, use_alpha_attribute = true, use_beta_attribute = true}, enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>], operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<0xf32>, tensor<0xf32>) -> tensor<64x64xf32>
// CUDA-NEXT: %3 = stablehlo.iota dim = 0 : tensor<64x64xi32>
// CUDA-NEXT: %4 = stablehlo.iota dim = 1 : tensor<64x64xi32>
// CUDA-NEXT: %5 = stablehlo.compare LT, %3, %4 : (tensor<64x64xi32>, tensor<64x64xi32>) -> tensor<64x64xi1>
// CUDA-NEXT: %6 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
// CUDA-NEXT: %7 = stablehlo.select %5, %2, %6 : tensor<64x64xi1>, tensor<64x64xf32>
// CUDA-NEXT: return %7 : tensor<64x64xf32>
// CUDA-NEXT: }

// TPU: func.func @main2(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
// TPU-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<64x64xf32>
// TPU-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<64x64xf32>
@@ -126,6 +151,13 @@ module {
// CPU-NEXT: return %0 : tensor<64x64xf32>
// CPU-NEXT: }

// CUDA: func.func @main3(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
// CUDA-NEXT: %0 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %1 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %2 = stablehlo.custom_call @reactant_cublas_syrk_ffi(%arg0, %arg1, %0, %1) {api_version = 4 : i32, backend_config = {alpha_imag = 0.000000e+00 : f64, alpha_real = 2.000000e+00 : f64, beta_imag = 0.000000e+00 : f64, beta_real = 3.000000e+00 : f64, transpose = false, uplo = false, use_alpha_attribute = true, use_beta_attribute = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<0xf32>, tensor<0xf32>) -> tensor<64x64xf32>
// CUDA-NEXT: return %2 : tensor<64x64xf32>
// CUDA-NEXT: }

// TPU: func.func @main3(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
// TPU-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<64x64xf32>
// TPU-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<64x64xf32>
@@ -177,10 +209,41 @@ module {
// CPU-NEXT: return %4 : tensor<4x4xf32>
// CPU-NEXT: }

// CUDA: func.func @main4(%arg0: tensor<5x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
// CUDA-NEXT{LITERAL}: %c = stablehlo.constant dense<[[false, true, true, true], [false, false, true, true], [false, false, false, true], [false, false, false, false]]> : tensor<4x4xi1>
// CUDA-NEXT: %cst = stablehlo.constant dense<5.000000e+00> : tensor<4x4xf32>
// CUDA-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<4x4xf32>
// CUDA-NEXT: %0 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %1 = tensor.empty() : tensor<0xf32>
// CUDA-NEXT: %2 = stablehlo.custom_call @reactant_cublas_syrk_ffi(%arg0, %cst_0, %0, %1) {api_version = 4 : i32, backend_config = {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta_imag = 0.000000e+00 : f64, beta_real = 0.000000e+00 : f64, transpose = true, uplo = true, use_alpha_attribute = true, use_beta_attribute = true}, enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>], operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<5x4xf32>, tensor<4x4xf32>, tensor<0xf32>, tensor<0xf32>) -> tensor<4x4xf32>
// CUDA-NEXT: %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32>
// CUDA-NEXT: %4 = stablehlo.select %c, %2, %3 : tensor<4x4xi1>, tensor<4x4xf32>
// CUDA-NEXT: %5 = stablehlo.multiply %cst, %arg1 : tensor<4x4xf32>
// CUDA-NEXT: %6 = stablehlo.add %4, %5 : tensor<4x4xf32>
// CUDA-NEXT: return %6 : tensor<4x4xf32>
// CUDA-NEXT: }

// TPU: func.func @main4(%arg0: tensor<5x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
// TPU-NEXT: %cst = stablehlo.constant dense<5.000000e+00> : tensor<4x4xf32>
// TPU-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0] : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<4x4xf32>
// TPU-NEXT: %1 = stablehlo.multiply %cst, %arg1 : tensor<4x4xf32>
// TPU-NEXT: %2 = stablehlo.add %0, %1 : tensor<4x4xf32>
// TPU-NEXT: return %2 : tensor<4x4xf32>
// TPU-NEXT: }

module {
func.func @main(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>, %alpha: tensor<f32>, %beta: tensor<f32>) -> tensor<64x64xf32> {
%0 = enzymexla.blas.syrk %arg0, %arg1, %alpha, %beta {fill, transpose = #enzymexla.transpose<none>, uplo = #enzymexla.uplo<U>} : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<f32>, tensor<f32>) -> tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}
}

// CUDA: func.func @main(%arg0: tensor<64x32xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<64x64xf32> {
// CUDA-NEXT: %0 = stablehlo.custom_call @reactant_cublas_syrk_ffi(%arg0, %arg1, %arg2, %arg3) {api_version = 4 : i32, backend_config = {alpha_imag = 0.000000e+00 : f64, alpha_real = 0.000000e+00 : f64, beta_imag = 0.000000e+00 : f64, beta_real = 0.000000e+00 : f64, transpose = false, uplo = true, use_alpha_attribute = false, use_beta_attribute = false}, enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>], operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<f32>, tensor<f32>) -> tensor<64x64xf32>
// CUDA-NEXT: %1 = stablehlo.iota dim = 0 : tensor<64x64xi32>
// CUDA-NEXT: %2 = stablehlo.iota dim = 1 : tensor<64x64xi32>
// CUDA-NEXT: %3 = stablehlo.compare LT, %1, %2 : (tensor<64x64xi32>, tensor<64x64xi32>) -> tensor<64x64xi1>
// CUDA-NEXT: %4 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
// CUDA-NEXT: %5 = stablehlo.select %3, %0, %4 : tensor<64x64xi1>, tensor<64x64xf32>
// CUDA-NEXT: return %5 : tensor<64x64xf32>
// CUDA-NEXT: }