Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support slice ops with default start #9923

Merged
merged 1 commit into from
Apr 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backends/xnnpack/operators/op_slice_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def define_node(
output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]

slice_begin_index = cast(int, node.args[2])
slice_begin_index = 0
if len(node.args) > 2 and node.args[2]:
slice_begin_index = cast(int, node.args[2])
if slice_begin_index < 0:
slice_begin_index = input_shape[dim_of_slice] + slice_begin_index

Expand Down
12 changes: 12 additions & 0 deletions backends/xnnpack/test/ops/test_slice_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def forward(self, x):
# Note that two of the slices are optimized away as they are identity.
self._test_slice_copy(ConvSlice(), inputs, 4, 2)

def test_fp32_slice_copy_default_start(self):
"""
XNNPACK supports default start in slice op.
"""

class Slice(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.slice.Tensor(x, 0, None, 2)

inputs = (torch.randn(5, 5),)
self._test_slice_copy(Slice(), inputs, 1, 1)

def test_fp32_slice_copy_stride_non_1(self):
"""
XNNPACK does not support strided slicing.
Expand Down
Loading