diff --git a/backends/xnnpack/operators/op_slice_copy.py b/backends/xnnpack/operators/op_slice_copy.py index 40d8e5f04eb..d9056afa832 100644 --- a/backends/xnnpack/operators/op_slice_copy.py +++ b/backends/xnnpack/operators/op_slice_copy.py @@ -69,7 +69,9 @@ def define_node( output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC] dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice] - slice_begin_index = cast(int, node.args[2]) + slice_begin_index = 0 + if len(node.args) > 2 and node.args[2]: + slice_begin_index = cast(int, node.args[2]) if slice_begin_index < 0: slice_begin_index = input_shape[dim_of_slice] + slice_begin_index diff --git a/backends/xnnpack/test/ops/test_slice_copy.py b/backends/xnnpack/test/ops/test_slice_copy.py index ea65571b1e8..857c78480ad 100644 --- a/backends/xnnpack/test/ops/test_slice_copy.py +++ b/backends/xnnpack/test/ops/test_slice_copy.py @@ -69,6 +69,18 @@ def forward(self, x): # Note that two of the slices are optimized away as they are identity. self._test_slice_copy(ConvSlice(), inputs, 4, 2) + def test_fp32_slice_copy_default_start(self): + """ + XNNPACK supports default start in slice op. + """ + + class Slice(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.slice.Tensor(x, 0, None, 2) + + inputs = (torch.randn(5, 5),) + self._test_slice_copy(Slice(), inputs, 1, 1) + def test_fp32_slice_copy_stride_non_1(self): """ XNNPACK does not support strided slicing.