Skip to content

Commit 4287c72

Browse files
authored
[MLIR][NVVM] Add tcgen05 alloc/dealloc Ops (#125674)
PR #124961 adds intrinsics for the tcgen05 alloc/dealloc PTX instructions. This patch adds NVVM Ops for the same. Tests are added to verify the lowering to the corresponding intrinsics in tcgen05-alloc.mlir file. PTX ISA link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions Signed-off-by: Durgadoss R <[email protected]>
1 parent 76d1cb2 commit 4287c72

File tree

4 files changed

+193
-1
lines changed

4 files changed

+193
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ enum NVVMMemorySpace {
3939
/// Shared memory space identifier.
4040
kSharedMemorySpace = 3,
4141
/// Constant memory space identifier.
42-
kConstantMemorySpace = 4
42+
kConstantMemorySpace = 4,
43+
/// Tensor memory space identifier.
44+
/// Tensor memory is available only in arch-accelerated
45+
/// variants from sm100 onwards.
46+
kTensorMemorySpace = 6
4347
};
4448

4549
/// Return the element type and number of elements associated with a wmma matrix

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

+105
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include "mlir/Interfaces/InferIntRangeInterface.td"
2323
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2424
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
2525
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
26+
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
2627

2728
//===----------------------------------------------------------------------===//
2829
// NVVM dialect definitions
@@ -2592,6 +2593,110 @@ def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
25922593
let assemblyFormat = "attr-dict";
25932594
}
25942595

2596+
//===----------------------------------------------------------------------===//
2597+
// NVVM TCGEN05 Ops
2598+
//===----------------------------------------------------------------------===//
2599+
// Num CTAs in a group participating in the TCGEN05 operation.
2600+
// This corresponds to the "cta_group::1", "cta_group::2"
2601+
// modifiers in the PTX instructions.
2602+
def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
2603+
def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
2604+
2605+
def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
2606+
"NVVM Tcgen05 group kind",
2607+
[Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
2608+
let genSpecializedAttr = 0;
2609+
let cppNamespace = "::mlir::NVVM";
2610+
}
2611+
def Tcgen05GroupKindAttr :
2612+
EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
2613+
let assemblyFormat = "`<` $value `>`";
2614+
}
2615+
2616+
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
2617+
let summary = "Tcgen05 alloc operation";
2618+
let description = [{
2619+
The `tcgen05.alloc` Op allocates tensor core memory for
2620+
the amount specified by `nCols` and writes the destination
2621+
address to the `addr` argument. The `nCols` operand specifies the
2622+
number of columns to be allocated and it must be a power-of-two.
2623+
[For more information, refer to the PTX ISA]
2624+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions)
2625+
}];
2626+
2627+
let arguments = (ins
2628+
AnyTypeOf<[LLVM_AnyPointer, LLVM_PointerShared]>:$addr,
2629+
I32:$nCols,
2630+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
2631+
2632+
let assemblyFormat = "$addr `,` $nCols attr-dict `:` type(operands)";
2633+
2634+
let extraClassDeclaration = [{
2635+
static llvm::Intrinsic::ID
2636+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2637+
llvm::SmallVector<llvm::Value *> &args);
2638+
}];
2639+
string llvmBuilder = [{
2640+
llvm::SmallVector<llvm::Value *> args;
2641+
auto id = NVVM::Tcgen05AllocOp::getIntrinsicIDAndArgs(
2642+
*op, moduleTranslation, args);
2643+
createIntrinsicCall(builder, id, args);
2644+
}];
2645+
}
2646+
2647+
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
2648+
let summary = "Tcgen05 dealloc operation";
2649+
let description = [{
2650+
The `tcgen05.dealloc` Op de-allocates the tensor core memory
2651+
specified by `tmemAddr`, which must be from a previous tensor
2652+
memory allocation. The `nCols` operand specifies the number
2653+
of columns to be de-allocated, and it must be a power-of-two.
2654+
[For more information, refer to the PTX ISA]
2655+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions)
2656+
}];
2657+
2658+
let arguments = (ins LLVM_PointerTensor:$taddr, I32:$nCols,
2659+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
2660+
2661+
let assemblyFormat = "$taddr `,` $nCols attr-dict `:` type(operands)";
2662+
2663+
let extraClassDeclaration = [{
2664+
static llvm::Intrinsic::ID
2665+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2666+
llvm::SmallVector<llvm::Value *> &args);
2667+
}];
2668+
string llvmBuilder = [{
2669+
llvm::SmallVector<llvm::Value *> args;
2670+
auto id = NVVM::Tcgen05DeallocOp::getIntrinsicIDAndArgs(
2671+
*op, moduleTranslation, args);
2672+
createIntrinsicCall(builder, id, args);
2673+
}];
2674+
}
2675+
2676+
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> {
2677+
let summary = "Tcgen05 Op to relinquish the right to allocate";
2678+
let description = [{
2679+
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
2680+
of the executing thread is relinquishing the right to allocate
2681+
Tensor Memory. So, it is illegal for a CTA to perform `tcgen05.alloc`
2682+
after any of its constituent threads execute `tcgen05.relinquish_alloc_permit`.
2683+
[For more information, refer to the PTX ISA]
2684+
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions)
2685+
}];
2686+
2687+
let arguments = (ins
2688+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
2689+
2690+
let assemblyFormat = "attr-dict";
2691+
2692+
string llvmBuilder = [{
2693+
auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
2694+
llvm::Intrinsic::nvvm_tcgen05_relinq_alloc_permit_cg1 :
2695+
llvm::Intrinsic::nvvm_tcgen05_relinq_alloc_permit_cg2;
2696+
createIntrinsicCall(builder, id);
2697+
}];
2698+
}
2699+
25952700
//===----------------------------------------------------------------------===//
25962701
// NVVM target attribute.
25972702
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,47 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12431243
}
12441244
}
12451245

1246+
llvm::Intrinsic::ID
1247+
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
1248+
LLVM::ModuleTranslation &mt,
1249+
llvm::SmallVector<llvm::Value *> &args) {
1250+
auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1251+
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1252+
.getAddressSpace();
1253+
bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
1254+
bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1255+
1256+
llvm::Intrinsic::ID id;
1257+
if (isShared) {
1258+
id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
1259+
: llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
1260+
} else {
1261+
id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
1262+
: llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
1263+
}
1264+
1265+
// Fill the Intrinsic Args
1266+
args.push_back(mt.lookupValue(curOp.getAddr()));
1267+
args.push_back(mt.lookupValue(curOp.getNCols()));
1268+
1269+
return id;
1270+
}
1271+
1272+
llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
1273+
Operation &op, LLVM::ModuleTranslation &mt,
1274+
llvm::SmallVector<llvm::Value *> &args) {
1275+
auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
1276+
auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
1277+
? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
1278+
: llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
1279+
1280+
// Fill the Intrinsic Args
1281+
args.push_back(mt.lookupValue(curOp.getTaddr()));
1282+
args.push_back(mt.lookupValue(curOp.getNCols()));
1283+
1284+
return id;
1285+
}
1286+
12461287
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
12471288
/// have ConstantRangeAttr.
12481289
static void nvvmInferResultRanges(Operation *op, Value result,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
2+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
3+
4+
// CHECK-LABEL: @llvm_nvvm_tcgen05_alloc
5+
llvm.func @llvm_nvvm_tcgen05_alloc(%addr : !llvm.ptr, %ncols : i32) {
6+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.cg1(ptr %{{.*}}, i32 %{{.*}})
7+
nvvm.tcgen05.alloc %addr, %ncols : !llvm.ptr, i32
8+
9+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.cg2(ptr %{{.*}}, i32 %{{.*}})
10+
nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr, i32
11+
llvm.return
12+
}
13+
14+
// CHECK-LABEL: @llvm_nvvm_tcgen05_alloc_shared
15+
llvm.func @llvm_nvvm_tcgen05_alloc_shared(%addr : !llvm.ptr<3>, %ncols : i32) {
16+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.shared.cg1(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
17+
nvvm.tcgen05.alloc %addr, %ncols : !llvm.ptr<3>, i32
18+
19+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.shared.cg2(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
20+
nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<3>, i32
21+
llvm.return
22+
}
23+
24+
// CHECK-LABEL: @llvm_nvvm_tcgen05_dealloc
25+
llvm.func @llvm_nvvm_tcgen05_dealloc(%addr : !llvm.ptr<6>, %ncols : i32) {
26+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.dealloc.cg1(ptr addrspace(6) %{{.*}}, i32 %{{.*}})
27+
nvvm.tcgen05.dealloc %addr, %ncols : !llvm.ptr<6>, i32
28+
29+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.dealloc.cg2(ptr addrspace(6) %{{.*}}, i32 %{{.*}})
30+
nvvm.tcgen05.dealloc %addr, %ncols {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>, i32
31+
llvm.return
32+
}
33+
34+
// CHECK-LABEL: @llvm_nvvm_tcgen05_relinquish_alloc_permit
35+
llvm.func @llvm_nvvm_tcgen05_relinquish_alloc_permit() {
36+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.relinq.alloc.permit.cg1()
37+
nvvm.tcgen05.relinquish_alloc_permit
38+
39+
// CHECK-LLVM: call void @llvm.nvvm.tcgen05.relinq.alloc.permit.cg2()
40+
nvvm.tcgen05.relinquish_alloc_permit {group = #nvvm.tcgen05_group<cta_2>}
41+
llvm.return
42+
}

0 commit comments

Comments
 (0)