1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
let hasCanonicalizer = 1;
}
#endif // AMDGPU
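
Note: with hasCanonicalizer = 1, ODS declares a hook on the op class that the canonicalize pass picks up. The generated declaration should look like the following sketch (generated code, not part of this diff):

    static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
                                            ::mlir::MLIRContext *context);

The matching definition is supplied by hand in AMDGPUDialect.cpp below.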
142 changes: 142 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <cstdint>
#include <limits>
#include <optional>

@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// Pack the scale operands of a ScaledMFMAOp. A scale can only be packed if
/// every scaled MFMA that uses it still has an opsel index of zero, i.e. the
/// scale has not been packed already.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ScaledMFMAOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
    // If a use of a scale is a scaled MFMA with a non-zero opsel, packing has
    // already been done. Any user that is not a scaled MFMA also blocks
    // packing.
    auto checkIfUnpackable = [&](OpOperand &use) {
      if (auto smfma = dyn_cast<ScaledMFMAOp>(use.getOwner())) {
        switch (use.getOperandNumber()) {
        case 3:
          return smfma.getScalesIdxA() != 0;
        case 4:
          return smfma.getScalesIdxB() != 0;
        default:
          return true;
        }
      }
      // Conservatively treat all other users as unpackable.
      return true;
    };

    auto setOpsel = [&](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Obtain the flat index encoded by a vector.extract's static offsets and
    // its source shape (row-major).
    auto getIdxFromExtract = [](vector::ExtractOp op) {
      ShapedType ty = cast<ShapedType>(op.getOperand(0).getType());
      int64_t cumul = 1;
      int64_t idx = 0;
      for (auto [offset, size] : llvm::reverse(
               llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
        idx += offset * cumul;
        cumul *= size;
      }
      return idx;
    };
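    // Worked example (mirrors the test below): extracting position
    // [1, 0, 4, 0] from vector<2x1x8x1xf8E8M0FNU> yields
    // idx = 1*(1*8*1) + 0*(8*1) + 4*(1) + 0 = 12.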

    // Delinearize a flat index into offsets for the given statically shaped
    // type (row-major).
    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
      SmallVector<int64_t> res;
      auto shapedTy = cast<ShapedType>(ty);
      int64_t numElements = shapedTy.getNumElements();
      for (int64_t size : shapedTy.getShape()) {
        numElements /= size;
        res.push_back(idx / numElements);
        idx %= numElements;
      }
      return res;
    };
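    // Continuing the example above: idx = 12 into the packed type
    // vector<4x4xf8E8M0FNU> delinearizes to [12 / 4, 12 % 4] = [3, 0], i.e.
    // packed row 3, opsel byte 0.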

    // For every scale operand of this ScaledMFMAOp, if the scale is produced
    // by the following pattern:
    //
    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
    // amdgpu.scaled_mfma(%scale[0] * ...
    //
    // rewrite it to:
    //
    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
    // amdgpu.scaled_mfma(%scale[0-3] * ...
    //
    // This creates a duplicate shape_cast for every use, but CSE removes the
    // duplicates.
    // Pack each scale operand independently; skip an operand when the
    // pattern does not apply to it. Report success only if at least one
    // operand was actually packed, so the op is never mutated on the
    // failure path.
    bool packed = false;
    for (int64_t opIdx : {3, 4}) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        continue;
      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable))
        continue;

      auto extractOp =
          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        continue;

      Value scaleSrc = extractOp.getOperand(0);
      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
      if (!stype)
        continue;
      // We do not handle dynamic dims yet; assume the input has been padded
      // to a static shape.
      if (!stype.hasStaticShape())
        continue;

      // The packed source has shape <numElements / 4 x 4>, so the element
      // count must be a non-trivial multiple of 4.
      int64_t numElements = stype.getNumElements();
      if (numElements <= 4 || numElements % 4 != 0)
        continue;

      Type newSrcType = VectorType::get({numElements / 4, 4},
                                        stype.getElementType());
      Value newScaleSrc =
          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
      int64_t idx = getIdxFromExtract(extractOp);
      SmallVector<int64_t> offsets = getOffsetsFromIdx(idx, newSrcType);
      auto scaleTy = VectorType::get({4}, stype.getElementType());
      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, newScaleSrc, /*offsets=*/ArrayRef<int64_t>{offsets[0], 0},
          /*sizes=*/ArrayRef<int64_t>{1, 4},
          /*strides=*/ArrayRef<int64_t>{1, 1});
      Value scale =
          rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
      rewriter.modifyOpInPlace(op, [&] {
        op.setOperand(opIdx, scale);
        setOpsel(opIdx, offsets[1]);
      });
      packed = true;
    }
    return success(packed);
}
};
} // namespace

void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<PackScales>(context);
}
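
For reference, the pattern can also be exercised programmatically, outside the -canonicalize pass. A minimal sketch (the greedy driver entry point is applyPatternsGreedily in current trees, applyPatternsAndFoldGreedily in older ones; moduleOp and context stand for the caller's op and MLIRContext):

    RewritePatternSet patterns(context);
    ScaledMFMAOp::getCanonicalizationPatterns(patterns, context);
    (void)applyPatternsGreedily(moduleOp, std::move(patterns));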

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
MLIRVectorDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRMemRefUtils
25 changes: 25 additions & 0 deletions mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
func.return
}

// -----

// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}
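
To reproduce the CHECK lines above locally, running the canonicalizer over this file should suffice (assuming the file's standard RUN line, which this hunk does not show):

    mlir-opt -canonicalize mlir/test/Dialect/AMDGPU/canonicalize.mlir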