From ce1f7b7ec462d381f1755ddcae5284a2729ea044 Mon Sep 17 00:00:00 2001 From: David Lerner Date: Sun, 4 May 2025 14:33:11 +0300 Subject: [PATCH] Let memref.collapse_shape implement ReifyRankedShapedTypeOpInterface. This MR implements ReifyRankedShapedTypeOpInterface for memref.collapse_shape and adds support in reifyResultShapes for memref.dim to operate directly on shaped values, eliminating reliance on collapse_shape. The new logic fully supports all collapse sizes and reifies dynamic dimensions, improving shape inference and lowering fidelity. --- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 3 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 31 +++++++ mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 91 +++++++++++++++++++ 3 files changed, 124 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index d6d8161d3117b..e401e3e8a53ae 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1761,7 +1761,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ } def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>{ let summary = "operation to produce a memref with a smaller rank."; let description = [{ The `memref.collapse_shape` op produces a new view with a smaller rank diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 6f10a31c15626..177b4a69d256f 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2482,6 +2482,37 @@ MemRefType CollapseShapeOp::computeCollapsedType( srcType.getMemorySpace()); } +// This method handles groups of dimensions where at least one dimension is dynamic. +// For each such group, it computes the combined size by multiplying all the sizes +// of the dimensions in that group. These computed sizes are then used to describe +// the resulting shape after collapsing. +LogicalResult CollapseShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) { + SmallVector reassociationArray = + getReassociationIndices(); + Value source = getSrc(); + Location loc = getLoc(); + SmallVector dynamicValues; + auto resultShape = cast(getResultType()).getShape(); + auto sourceShape = cast(source.getType()).getShape(); + for (auto group : reassociationArray) { + if (!llvm::any_of(group, [&](int64_t dim) { + return ShapedType::isDynamic(sourceShape[dim]); + })) + continue; + Value resultVal = builder.create(loc, source, group[0]); + for (auto dim : llvm::drop_begin(group)) { + Value nextVal = builder.create(loc, source, dim); + resultVal = builder.create(loc, resultVal, nextVal); + } + + dynamicValues.push_back(resultVal); + } + + reifiedResultShapes = {getMixedValues(resultShape, dynamicValues, builder)}; + return success(); +} + void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir index e354eb91d7557..f40b0ad849fa0 100644 --- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir +++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir @@ -97,3 +97,94 @@ func.func @iter_to_init_arg_loop_like( } return %result : tensor } + +// ----- + +// CHECK-LABEL: func.func @collapse_dynamic_with_unit_dims( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x1xsi8>) -> index { +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x1xsi8> +// CHECK: return %[[dim]] : index +// CHECK: } +func.func @collapse_dynamic_with_unit_dims (%arg0: memref<1x32x?x1xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x1xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_right( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x8x?xsi8>) -> index { +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x8x?xsi8> +// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index +// CHECK: return %[[res]] : index +// CHECK: } +func.func @fold_dynamic_and_const_with_dynamic_on_right(%arg0: memref<1x32x8x?xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x8x?xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_left( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x8xsi8>) -> index { +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x8xsi8> +// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index +// CHECK: return %[[res]] : index +// CHECK: } +func.func @fold_dynamic_and_const_with_dynamic_on_left(%arg0: memref<1x32x?x8xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_more_than_two_elements_group( +// CHECK-SAME: %[[arg0:.*]]: memref<2x32x?x8xsi8>) -> index { +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[c64:.*]] = arith.constant 64 : index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<2x32x?x8xsi8> +// CHECK: %[[res0:.*]] = arith.muli %[[dim]], %[[c64]] : index +// CHECK: %[[res1:.*]] = arith.muli %[[res0]], %[[c8]] : index +// CHECK: return %[[res1]] : index +// CHECK: } +func.func @fold_more_than_two_elements_group(%arg0: memref<2x32x?x8xsi8>) + -> index { + %c1 = arith.constant 0 : index + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<2x32x?x8xsi8> into memref + %dim_3 = memref.dim %collapse_shape, %c1 : memref + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_group_with_two_dynamic( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x?xsi8>) -> index { +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim2:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x?xsi8> +// CHECK: %[[dim3:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x?x?xsi8> +// CHECK: %[[res:.*]] = arith.muli %[[dim2]], %[[dim3]] : index +// CHECK: return %[[res]] : index +// CHECK: } +func.func @fold_group_with_two_dynamic(%arg0: memref<1x32x?x?xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x?xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +}