Skip to content

Commit

Permalink
[DT][GPU] Implement EncodingLayoutAttrInterface
Browse files Browse the repository at this point in the history
Signed-off-by: Jorn Tuyls <[email protected]>
  • Loading branch information
jtuyls committed Feb 20, 2025
1 parent fb3523b commit 4f00a07
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,18 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
} else if (isROCMBackend(targetAttr)) {
LDBG("Select GPUEncodingLayoutAttr attribute as the layout attribute.");
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
IREE::GPU::GPUEncodingLayoutAttr::get(ctx,
getGPUTargetAttr(targetAttr)));
IREE::GPU::GPUEncodingLayoutAttr::get(
ctx, DictionaryAttr::get(
ctx, NamedAttribute(kGPUTargetAttrName,
getGPUTargetAttr(targetAttr)))));
} else if (testCLGPUTarget) {
LDBG("Select GPUEncodingLayoutAttr attribute as the layout attribute. "
"(testCLGPUTarget)");
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
IREE::GPU::GPUEncodingLayoutAttr::get(ctx, getCLGPUTarget(ctx)));
IREE::GPU::GPUEncodingLayoutAttr::get(
ctx,
DictionaryAttr::get(ctx, NamedAttribute(kGPUTargetAttrName,
getCLGPUTarget(ctx)))));
} else {
LDBG("Select EncodingNopLayoutAttr attribute as the layout attribute.");
layoutAttr = IREE::Codegen::EncodingNopLayoutAttr::get(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1190,3 +1190,37 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_BF16, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----

//----------------------------------------------------------------------------//
// Test suite for encodings with resolved layouts.
//----------------------------------------------------------------------------//

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip"}>
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], layouts = [#iree_gpu.gpu_encoding_layout<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [128, 16], outerDimsPerm = [0, 1]}}>]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
func.func @set_encoding_with_layout() attributes {
hal.executable.target = #executable_target_rocm_hsaco_fb
} {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
%3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
return
}
// CHECK-LABEL: func.func @set_encoding_with_layout
// CHECK-DAG: %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x33x128x16xf32>
// CHECK: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
// CHECK: %[[PACK:.+]] = linalg.pack %[[INPUT]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
// CHECK-SAME: inner_tiles = [128, 16]
// CHECK-SAME: tensor<255x513xf32> -> tensor<2x33x128x16xf32>
// CHECK: flow.dispatch.tensor.store %[[PACK]], %[[RESULT_BINDING]]
Original file line number Diff line number Diff line change
Expand Up @@ -2959,7 +2959,7 @@ func.func @broadcast_K() attributes {
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
#executable_target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz", cpu_features = "+avx512f", encoding = #iree_cpu.cpu_encoding_layout<>}>
#executable_target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz", cpu_features = "+avx512f"}>
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], layouts = [#iree_cpu.cpu_encoding_layout<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [1, 1], outerDimsPerm = [0, 1]}}>]>
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
MaterializeEncodingInfo
getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) {
MaterializeEncodingInfo encodingInfo;
auto cDims = getEncodingContractionDims(encoding);
FailureOr<linalg::ContractionDimensions> cDims =
getEncodingContractionDims(encoding);
if (failed(cDims)) {
return encodingInfo;
}
// The following expects M, N, K, and Batch sizes of at most 1 for now
assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() == 1 &&
cDims->batch.size() <= 1 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,8 @@ def IREEGPU_GPUEncodingLayoutAttr :
let assemblyFormat = "`<` struct(params) `>`";

let parameters = (ins
OptionalParameter<"::mlir::iree_compiler::IREE::GPU::TargetAttr",
"IREE GPU target attribute. It is expected to be used in a pass scope, "
"but not the final IR output.">:$targetAttr
OptionalParameter<"DictionaryAttr", "Executable target configuration. It is "
"expected to be used in a pass scope, but not the final IR output.">:$configuration
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,6 @@ namespace {
// Utilities.
//===----------------------------------------------------------------------===//

/// Appends the NamedAttribute into `config` if there is a `name` NamedAttribute
/// in the `dictAttr`.
static void storeNamedAttrIfPresent(SmallVectorImpl<NamedAttribute> &config,
DictionaryAttr dictAttr, StringRef name) {
auto attr = dictAttr.getNamed(name);
if (!attr) {
return;
}
config.push_back(attr.value());
}

static void transposeInPlace(MaterializeEncodingInfo &info) {
// Vector cases: nothing to do.
if (info.innerTileSizes.size() < 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/ExternalInterfaces/Utils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -303,6 +305,8 @@ struct GPUDeviceEncodingLayoutResolverAttrInterface
MaterializeEncodingInfo getEncodingInfo(Attribute attr,
RankedTensorType type) const {
auto layoutAttr = cast<GPUEncodingLayoutAttr>(attr);
DictionaryAttr config = layoutAttr.getConfiguration();

auto encoding = llvm::dyn_cast_or_null<IREE::Encoding::EncodingAttr>(
type.getEncoding());

Expand All @@ -311,8 +315,24 @@ struct GPUDeviceEncodingLayoutResolverAttrInterface
return info;
}

// If the layout is already resolved, use it directly.
if (config) {
if (std::optional<NamedAttribute> namedAttr =
config.getNamed(kEncodingInfoAttrName)) {
std::optional<MaterializeEncodingInfo> preresolvedInfo =
Codegen::deserializeEncodingInfo(
cast<DictionaryAttr>(namedAttr->getValue()));
assert(preresolvedInfo && "encoding_info is invalid");
return preresolvedInfo.value();
}
}

IREE::GPU::TargetAttr gpuAttr = getGPUTargetAttr(config);
if (!gpuAttr) {
return info;
}
DataTiledMMAAttr mma = chooseDataTiledMMAAttr(
encoding.getElementTypesArray(), layoutAttr.getTargetAttr(), encoding);
encoding.getElementTypesArray(), gpuAttr, encoding);
if (!mma) {
return info;
}
Expand All @@ -336,8 +356,37 @@ struct GPUDeviceEncodingLayoutResolverAttrInterface
if (!linalgOp) {
return nullptr;
}
DictionaryAttr config = layoutAttr.getConfiguration();
IREE::GPU::TargetAttr gpuAttr = getGPUTargetAttr(config);
if (!gpuAttr) {
return nullptr;
}
return lowerContractionOpToMultiMmaOp(b, linalgOp, convertedOperands,
layoutAttr.getTargetAttr());
gpuAttr);
}
};

struct GPUHostEncodingLayoutResolverAttrInterface final
: IREE::Encoding::EncodingLayoutResolverAttrInterface::ExternalModel<
GPUHostEncodingLayoutResolverAttrInterface, GPUEncodingLayoutAttr> {
Attribute cloneWithSimplifiedConfig(Attribute attr,
DictionaryAttr config) const {
MLIRContext *ctx = attr.getContext();
SmallVector<NamedAttribute> configItems;
DictionaryAttr existingConfig =
cast<GPUEncodingLayoutAttr>(attr).getConfiguration();
if (existingConfig) {
configItems.append(existingConfig.getValue().begin(),
existingConfig.getValue().end());
}
storeNamedAttrIfPresent(configItems, config, kGPUTargetAttrName);
return GPUEncodingLayoutAttr::get(ctx,
DictionaryAttr::get(ctx, configItems));
}

Attribute getLayout(Attribute attr, RankedTensorType type) const {
MLIRContext *ctx = attr.getContext();
return GPUEncodingLayoutAttr::get(ctx, getLayoutImpl(attr, type));
}
};

Expand Down Expand Up @@ -435,7 +484,8 @@ void registerGPUEncodingExternalModels(DialectRegistry &registry) {
registry.addExtension(
+[](MLIRContext *ctx, IREE::GPU::IREEGPUDialect *dialect) {
IREE::GPU::GPUEncodingLayoutAttr::attachInterface<
GPUDeviceEncodingLayoutResolverAttrInterface>(*ctx);
GPUDeviceEncodingLayoutResolverAttrInterface,
GPUHostEncodingLayoutResolverAttrInterface>(*ctx);
IREE::GPU::GPUPadLayoutAttr::attachInterface<
GPUPadEncodingLayoutResolverAttrInterface>(*ctx);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,13 @@ DictionaryAttr getLayoutImpl(Attribute attr, RankedTensorType type) {
ctx, NamedAttribute(kEncodingInfoAttrName, encodingInfoAttr));
}

void storeNamedAttrIfPresent(SmallVectorImpl<NamedAttribute> &config,
DictionaryAttr dictAttr, StringRef name) {
auto attr = dictAttr.getNamed(name);
if (!attr) {
return;
}
config.push_back(attr.value());
}

} // namespace mlir::iree_compiler::IREE
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ Value calculateStorageSizeInBytesImpl(Attribute attr, Location loc,
/// Requirement: `attr` must implement IREE::Codegen::LayoutAttrInterface.
DictionaryAttr getLayoutImpl(Attribute attr, RankedTensorType type);

/// Appends the NamedAttribute into `config` if there is a `name` NamedAttribute
/// in the `dictAttr`.
void storeNamedAttrIfPresent(SmallVectorImpl<NamedAttribute> &config,
DictionaryAttr dictAttr, StringRef name);

} // namespace mlir::iree_compiler::IREE

#endif // IREE_COMPILER_CODEGEN_EXTERNALINTERFACES_UTILSS_H_
23 changes: 18 additions & 5 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,12 +967,25 @@ IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context) {
return IREE::GPU::getFullTarget(backend, arch, features, context);
}

IREE::GPU::TargetAttr getGPUTargetAttr(IREE::HAL::ExecutableTargetAttr target) {
if (auto config = target.getConfiguration()) {
if (auto attr = config.getAs<IREE::GPU::TargetAttr>(kGPUTargetAttrName))
return attr;
IREE::GPU::TargetAttr getGPUTargetAttr(Attribute attr) {
if (!attr) {
return {};
}
return getCLGPUTarget(target.getContext());
DictionaryAttr config;
auto targetAttr = dyn_cast<IREE::HAL::ExecutableTargetAttr>(attr);
if (targetAttr) {
config = targetAttr.getConfiguration();
} else {
config = dyn_cast<DictionaryAttr>(attr);
}
if (!config) {
return getCLGPUTarget(attr.getContext());
}
auto gpuAttr = config.getAs<IREE::GPU::TargetAttr>(kGPUTargetAttrName);
if (!gpuAttr) {
return getCLGPUTarget(attr.getContext());
}
return gpuAttr;
}

IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op) {
Expand Down
11 changes: 6 additions & 5 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,15 @@ FailureOr<ArrayAttr> getSupportedMmaTypes(DictionaryAttr config);
FailureOr<ArrayAttr> getSupportedMmaTypes(mlir::FunctionOpInterface entryPoint);

/// Returns the GPU target attribute from `iree-gpu-test-target` if provided.
/// Returns null TargetAttr othersise.
/// Returns null TargetAttr otherwise.
IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context);

/// Returns the GPU target attribute from executable |target| if found.
/// Returns null TargetAttr othersise.
IREE::GPU::TargetAttr getGPUTargetAttr(IREE::HAL::ExecutableTargetAttr target);
/// Returns the GPU target attribute from attribute `attr` if found. The `attr`
/// can be either IREE::HAL::ExecutableTargetAttr or DictionaryAttr.
/// Returns null TargetAttr otherwise.
IREE::GPU::TargetAttr getGPUTargetAttr(Attribute attr);
/// Returns the GPU target attribute from the executable target wrapping |op|
/// if found. Returns null TargetAttr othersise.
/// if found. Returns null TargetAttr otherwise.
IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op);

/// Returns the GPU subgroup size chosen for the current CodeGen pipeline if
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,51 @@ util.func public @tensor_sizeof(%d0: index, %d1: index) -> (index, index) {

// -----

//------------------------------------------------------------------------------
// #iree_gpu.gpu_encoding_layout specialization tests.
// These get serialized to the layout attributes.
//------------------------------------------------------------------------------

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
{
abi = "hip",
encoding = #iree_gpu.gpu_encoding_layout<>,
iree.gpu.target = #iree_gpu.target<arch = "gfx942",
features = "",
wgp = <compute = fp32,
storage = b32,
subgroup = none,
dot = none,
mma = [<MFMA_F32_16x16x4_F32>],
subgroup_size_choices = [64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128,
simds_per_wgp = 4,
vgpr_space_bits = 16384>>
}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_rocm_hsaco_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>

util.global private @device_a = #device_target_local_0_
util.func public @gpu_with_encoding_layout(%d0: index, %d1: index) -> index {
%size0 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?xf32, #encoding>{%d0, %d1} : index
util.return %size0 : index
}
// CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding
// CHECK-SAME: #iree_gpu.gpu_encoding_layout
// CHECK-SAME: encoding_info = {innerDimsPos = [{{.+}}], innerTileSizes = [{{.+}}], outerDimsPerm = [{{.+}}]}
// CHECK-LABEL: util.func public @gpu_with_encoding_layout
// CHECK: %[[RES:.+]] = stream.tensor.sizeof {{.+}} tensor<?x?xf32, #[[$ENCODING]]>
// CHECK: return %[[RES]]

// -----

//------------------------------------------------------------------------------
// iree_gpu.gpu_pad_encoding specialization tests.
// These get serialized to iree_encoding.pad_encoding_layout attributes.
Expand Down

0 comments on commit 4f00a07

Please sign in to comment.