Skip to content

Commit

Permalink
Implement lowering of torch.aten.hstack (#3563)
Browse files Browse the repository at this point in the history
  • Loading branch information
BaneTrifa authored Sep 11, 2024
1 parent 0474082 commit 1c4b9d6
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 0 deletions.
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14121,6 +14121,29 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
}];
}

def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::hstack : (Tensor[]) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenHstackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenHstackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
AllowsTypeRefinement
]> {
Expand Down
52 changes: 52 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10639,6 +10639,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hstack\"(%arg0: !torch.list<list<int>>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %1, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %7 = func.call @\"__torch_mlir_shape_fn.aten.atleast_1d\"(%6) : (!torch.list<int>) -> !torch.list<int>\n"
" %8 = torch.aten.append.t %0, %7 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %2 = torch.aten.__getitem__.t %0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %3 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" } else {\n"
" %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" }\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
Expand Down Expand Up @@ -15185,6 +15210,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hstack\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %4, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
Expand Down
53 changes: 53 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3813,6 +3813,58 @@ class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
};
} // namespace

// Decompose `aten.hstack` into `aten.at_least1d` and `aten.cat`.
// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L3908
namespace {
class DecomposeAtenHstackOp : public OpRewritePattern<AtenHstackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenHstackOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

// Get SmallVector<Value> from Value.
SmallVector<Value> tensors;
if (!getListConstructElements(op.getTensors(), tensors))
return rewriter.notifyMatchFailure(
op, "unimplemented: the tensor list is not from list construct");

// Execute AtenAtleast1dOp on every tensor inside tensors.
SmallVector<Value> atleast1dTensors;
for (auto tensor : tensors) {
std::optional<unsigned> tensorRank = getTensorRank(tensor);

// Check if the tensor is already of rank >= 1.
if (*tensorRank < 1) {
auto atleast1dTensor =
rewriter.create<AtenAtleast1dOp>(loc, tensor.getType(), tensor);
atleast1dTensors.push_back(atleast1dTensor);
} else {
atleast1dTensors.push_back(tensor);
}
}

// Make Value list from atleast1dTensors variable.
auto elemType = cast<BaseTensorType>(atleast1dTensors[0].getType())
.getWithSizesAndDtype(std::nullopt, nullptr);
Value atleast1dTensorList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(elemType), atleast1dTensors);

// Replace hstack with cat operator.
if (getTensorRank(atleast1dTensors[0]) == 1)
rewriter.replaceOpWithNewOp<AtenCatOp>(
op, op.getType(), atleast1dTensorList,
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)));
else
rewriter.replaceOpWithNewOp<AtenCatOp>(
op, op.getType(), atleast1dTensorList,
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)));

return success();
}
};
} // namespace

// Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html
namespace {
Expand Down Expand Up @@ -9567,6 +9619,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHstackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenHstackOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
Expand Down
13 changes: 13 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,10 @@
"GridSamplerBasic4_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorStaticModule_basic",
"IntFloatModule_basic",
Expand Down Expand Up @@ -2215,6 +2219,11 @@
# failed to legalize operation 'torch.aten.rrelu_with_noise'
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
# incompatible return type failure for tosa.concat.
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
# Shape Related failures
"PrimListUnpackNumMismatchModule_basic",
"ReshapeExpandModule_basic",
Expand Down Expand Up @@ -2623,6 +2632,10 @@
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HardtanhBackward_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,19 @@ def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.stack(tensors, dim)


@check_shape_function([
Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case.
])
def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]:

tensors_atleast1d = [aten〇atleast_1d〡shape(tensor) for tensor in tensors]

if len(tensors_atleast1d[0]) == 1:
return upstream_shape_functions.cat(tensors_atleast1d, dim=0)

return upstream_shape_functions.cat(tensors_atleast1d, dim=1)

def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
return self

Expand Down Expand Up @@ -5325,6 +5338,23 @@ def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]),
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
NonZeroDTensorWithDtype(torch.complex64)])])
def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int:
ranks: List[Optional[int]] = []
dtypes: List[int] = []
assert len(tensors_rank_dtype) != 0
for tensor_rank_dtype in tensors_rank_dtype:
tensor_rank, tensor_dtype = tensor_rank_dtype
ranks.append(tensor_rank)
dtypes.append(tensor_dtype)

return promote_dtypes(ranks, dtypes)

@check_dtype_function(
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.int32)]),])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,7 @@ def emit_with_mutating_variants(key, **kwargs):
has_folder=True,
)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::hstack : (Tensor[]) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
Expand Down
101 changes: 101 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,107 @@ def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils):
# ==============================================================================


class HstackBasicIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 3, 4], torch.bool, True),
([2, 3, 4], torch.int32, True),
([2, 3, 4], torch.int64, True),
]
)
def forward(self, x, y, z):
return torch.ops.aten.hstack([x, y, z])


@register_test_case(module_factory=lambda: HstackBasicIntModule())
def HstackBasicIntModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(2, 3, 4, low=0, high=2).bool(),
tu.randint(2, 3, 4, low=0, high=100).int(),
tu.randint(2, 3, 4, low=0, high=100).long(),
)


class HstackBasicFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 6, 4], torch.int32, True),
([2, 3, 4], torch.float64, True),
]
)
def forward(self, x, y):
return torch.ops.aten.hstack([x, y])


@register_test_case(module_factory=lambda: HstackBasicFloatModule())
def HstackBasicFloatModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 6, 4).int(),
tu.rand(2, 3, 4).double(),
)


class HstackBasicIntFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int32, True),
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.hstack([x, y])


@register_test_case(module_factory=lambda: HstackBasicIntFloatModule())
def HstackBasicIntFloatModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, 6, 4, 2, low=1, high=50).int(),
tu.rand(4, 3, 4, 2),
)


class HstackBasicComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.complex64, True),
([-1, -1, -1, -1], torch.complex128, True),
]
)
def forward(self, x, y):
return torch.ops.aten.hstack([x, y])


@register_test_case(module_factory=lambda: HstackBasicComplexModule())
def HstackBasicComplexModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(4, 6, 4, 2).type(torch.complex64),
tu.rand(4, 3, 4, 2).type(torch.complex128),
)


# ==============================================================================


class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 1c4b9d6

Please sign in to comment.