diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index d6d8161d3117b..e401e3e8a53ae 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1761,7 +1761,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ } def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>{ let summary = "operation to produce a memref with a smaller rank."; let description = [{ The `memref.collapse_shape` op produces a new view with a smaller rank diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index a0237c18cf2fe..95805204be3f7 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2481,6 +2481,37 @@ MemRefType CollapseShapeOp::computeCollapsedType( srcType.getMemorySpace()); } +// This method handles groups of dimensions where at least one dimension is dynamic. +// For each such group, it computes the combined size by multiplying all the sizes +// of the dimensions in that group. These computed sizes are then used to describe +// the resulting shape after collapsing. +LogicalResult CollapseShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) { + SmallVector reassociationArray = + getReassociationIndices(); + Value source = getSrc(); + Location loc = getLoc(); + SmallVector dynamicValues; + auto resultShape = cast(getResultType()).getShape(); + auto sourceShape = cast(source.getType()).getShape(); + for (auto group : reassociationArray) { + if (!llvm::any_of(group, [&](int64_t dim) { + return ShapedType::isDynamic(sourceShape[dim]); + })) + continue; + Value resultVal = builder.create(loc, source, group[0]); + for (auto dim : llvm::drop_begin(group)) { + Value nextVal = builder.create(loc, source, dim); + resultVal = builder.create(loc, resultVal, nextVal); + } + + dynamicValues.push_back(resultVal); + } + + reifiedResultShapes = {getMixedValues(resultShape, dynamicValues, builder)}; + return success(); +} + void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir index e354eb91d7557..f40b0ad849fa0 100644 --- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir +++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir @@ -97,3 +97,94 @@ func.func @iter_to_init_arg_loop_like( } return %result : tensor } + +// ----- + +// CHECK-LABEL: func.func @collapse_dynamic_with_unit_dims( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x1xsi8>) -> index { +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x1xsi8> +// CHECK: return %[[dim]] : index +// CHECK: } +func.func @collapse_dynamic_with_unit_dims (%arg0: memref<1x32x?x1xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x1xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_right( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x8x?xsi8>) -> index { +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x8x?xsi8> +// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index +// CHECK: return %[[res]] : index +// CHECK: } +func.func @fold_dynamic_and_const_with_dynamic_on_right(%arg0: memref<1x32x8x?xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x8x?xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_left( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x8xsi8>) -> index { +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x8xsi8> +// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index +// CHECK: return %[[res]] : index +// CHECK: } +func.func @fold_dynamic_and_const_with_dynamic_on_left(%arg0: memref<1x32x?x8xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_more_than_two_elements_group( +// CHECK-SAME: %[[arg0:.*]]: memref<2x32x?x8xsi8>) -> index { +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[c64:.*]] = arith.constant 64 : index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<2x32x?x8xsi8> +// CHECK: %[[res0:.*]] = arith.muli %[[dim]], %[[c64]] : index +// CHECK: %[[res1:.*]] = arith.muli %[[res0]], %[[c8]] : index +// CHECK: return %[[res1]] : index +// CHECK: } +func.func @fold_more_than_two_elements_group(%arg0: memref<2x32x?x8xsi8>) + -> index { + %c1 = arith.constant 0 : index + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<2x32x?x8xsi8> into memref + %dim_3 = memref.dim %collapse_shape, %c1 : memref + return %dim_3: index +} + +// ----- + +// CHECK-LABEL: func.func @fold_group_with_two_dynamic( +// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x?xsi8>) -> index { +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: %[[dim2:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x?xsi8> +// CHECK: %[[dim3:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x?x?xsi8> +// CHECK: %[[res:.*]] = arith.muli %[[dim2]], %[[dim3]] : index +// CHECK: return %[[res]] : index +// CHECK: } +func.func @fold_group_with_two_dynamic(%arg0: memref<1x32x?x?xsi8>) + -> index { + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x?xsi8> into memref<1x32x?xsi8> + %dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8> + return %dim_3: index +}