Skip to content

Commit bfe6e5c

Browse files
pssrawatfacebook-github-bot
authored andcommitted
Support slice ops with default start
Summary: Since D71962884, we see the following slice ops in ASR encoder: {F1976830836} This is causing failure during XNNPack delegation, since XNNPack slice pass is trying to compare start_idx 'None' to 0. This diff fixes that. Differential Revision: D72503552
1 parent b7eee0c commit bfe6e5c

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

backends/xnnpack/operators/op_slice_copy.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def define_node(
6969
output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
7070
dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]
7171

72-
slice_begin_index = cast(int, node.args[2])
72+
slice_begin_index = 0
73+
if len(node.args) > 2 and node.args[2]:
74+
slice_begin_index = cast(int, node.args[2])
7375
if slice_begin_index < 0:
7476
slice_begin_index = input_shape[dim_of_slice] + slice_begin_index
7577

backends/xnnpack/test/ops/test_slice_copy.py

+13
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,19 @@ def forward(self, x):
6969
# Note that two of the slices are optimized away as they are identity.
7070
self._test_slice_copy(ConvSlice(), inputs, 4, 2)
7171

72+
def test_fp32_slice_copy_default_start(self):
73+
"""
74+
XNNPACK supports default start in slice op.
75+
"""
76+
77+
class Slice(torch.nn.Module):
78+
def forward(self, x):
79+
return torch.ops.aten.slice.Tensor(x, 0, None, 2)
80+
81+
inputs = (torch.randn(5, 5),)
82+
self._test_slice_copy(Slice(), inputs, 1, 1)
83+
84+
7285
def test_fp32_slice_copy_stride_non_1(self):
7386
"""
7487
XNNPACK does not support strided slicing.

0 commit comments

Comments
 (0)