diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 498b97503e276a..897af30b09a087 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -797,8 +797,14 @@ bool index_tensor_on_list(ov::pass::NodeRegistry& rg, if (id_dtype == element::boolean || id_dtype == element::u8) { auto idx = rg.make(indices[i], element::u8); auto nonzero = rg.make(idx); - auto input_order = rg.make(element::i32, Shape{2}, std::vector{1, 0}); - auto masked_id = rg.make(nonzero, input_order); + Output masked_id; + if (indices.size() == 1) { + auto input_order = rg.make(element::i32, Shape{2}, std::vector{1, 0}); + masked_id = rg.make(nonzero, input_order); + } else { + auto zero_const = rg.make(element::i32, Shape{1}, 0); + masked_id = rg.make(nonzero, zero_const); + } masked_indicies.push_back(masked_id); is_masked_bool.push_back(true); } else { @@ -815,7 +821,7 @@ bool index_tensor_on_list(ov::pass::NodeRegistry& rg, return true; } // perform gather for single element case - if (advanced_ids.size() == 1) { + if (advanced_ids.size() == 1 && advanced_ids[0] == 0) { auto index = masked_indicies[advanced_ids[0]]; if (is_masked_bool[advanced_ids[0]]) { auto gather = rg.make(data, index); diff --git a/tests/layer_tests/pytorch_tests/test_index.py b/tests/layer_tests/pytorch_tests/test_index.py index a085f6c1c8fa21..82a64d4da96451 100644 --- a/tests/layer_tests/pytorch_tests/test_index.py +++ b/tests/layer_tests/pytorch_tests/test_index.py @@ -9,44 +9,47 @@ class TestIndex(PytorchLayerTest): - def _prepare_input(self, input_shape, idx): - import numpy as np - return (np.random.randn(*input_shape).astype(np.float32), idx) + def _prepare_input(self, input_shape, idx=None): + rng = np.random.default_rng(42) + x = rng.standard_normal(size=input_shape, dtype=np.float32) + return (x,) if idx is None else (x, idx) def create_model(self, model="list"): - import torch - class aten_index_list(torch.nn.Module): - def forward(self, x, idx): return x[idx] class aten_index_getitem(torch.nn.Module): - def forward(self, x, idx): return x.__getitem__(idx) class aten_index_list_bool(torch.nn.Module): - def forward(self, x, idx): return x[idx.to(torch.bool)] class aten_index_getitem_bool(torch.nn.Module): - def forward(self, x, idx): return x.__getitem__(idx.to(torch.bool)) + + class aten_index_bool_with_axis(torch.nn.Module): + def __init__(self): + super().__init__() + self.idx = torch.tensor([1, 0, 1, 0, 1], dtype=torch.bool) + + def forward(self, x): + return x[:,:,self.idx] + cases = { "list": aten_index_list, "getitem": aten_index_getitem, "list_with_bool": aten_index_list_bool, - "getitem_with_bool": aten_index_getitem_bool + "getitem_with_bool": aten_index_getitem_bool, + "bool_with_axis": aten_index_bool_with_axis, } aten_index = cases[model] - ref_net = None - - return aten_index(), ref_net, "aten::index" + return aten_index(), None, "aten::index" @pytest.mark.nightly @pytest.mark.precommit @@ -68,43 +71,43 @@ def test_index(self, input_shape, idx, case, ie_device, precision, ir_version): ((1, 2), np.array([[1, 0]]).astype(bool)), ((2, 2, 5), np.zeros([2, 2, 5]).astype(bool)), ((2, 2, 5), np.ones([2, 2, 5]).astype(bool)), - ((2, 2, 5), np.random.rand(2, 2, 5) > 0) + ((2, 2, 5), np.array([[[1, 0, 1, 0, 1], [0, 1, 0, 1, 0]], + [[1, 1, 0, 0, 1], [0, 0, 1, 1, 0]]], dtype=bool)) ]) def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version): self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx}) + @pytest.mark.nightly + @pytest.mark.precommit + def test_index_bool_with_axis(self, ie_device, precision, ir_version): + self._test(*self.create_model("bool_with_axis"), ie_device, precision, ir_version, + kwargs_to_prepare_input={"input_shape": (2, 2, 5)}, trace_model=True) + class TestIndexRange(PytorchLayerTest): def _prepare_input(self, input_shape, idx): - import numpy as np - return (np.random.randn(*input_shape).astype(np.float32), np.array(idx).astype(np.int32)) + rng = np.random.default_rng(42) + x = rng.standard_normal(size=input_shape, dtype=np.float32) + return (x, np.array(idx).astype(np.int32)) def create_model(self): - import torch - class aten_index_arange(torch.nn.Module): def forward(self, x, y): x = x.reshape(x.shape[0], -1) return x[torch.arange(x.shape[0]), y] - ref_net = None - - return aten_index_arange(), ref_net, "aten::index" + return aten_index_arange(), None, "aten::index" def create_model2(self): - import torch - class aten_index_arange(torch.nn.Module): def forward(self, x, y): x = x.reshape(x.shape[0], x.shape[1], -1, 1) return x[torch.arange(x.shape[0]), y] - ref_net = None - - return aten_index_arange(), ref_net, "aten::index" + return aten_index_arange(), None, "aten::index" @pytest.mark.nightly @pytest.mark.precommit @@ -131,8 +134,8 @@ def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_ class TestIndexMask(PytorchLayerTest): def _prepare_input(self, input_shape): - import numpy as np - return (np.random.randn(*input_shape).astype(np.float32),) + rng = np.random.default_rng(42) + return (rng.standard_normal(size=input_shape, dtype=np.float32),) def create_model(self): import torch @@ -141,9 +144,7 @@ class aten_index_mask(torch.nn.Module): def forward(self, x): return x[x > 0] - ref_net = None - - return aten_index_mask(), ref_net, "aten::index" + return aten_index_mask(), None, "aten::index" @pytest.mark.nightly @pytest.mark.precommit @@ -158,12 +159,12 @@ def test_index_mask(self, input_shape, ie_device, precision, ir_version): class TestIndexNone(PytorchLayerTest): def _prepare_input(self, input_shape): - import numpy as np - return (np.random.randn(*input_shape).astype(np.float32),) + rng = np.random.default_rng(42) + return (rng.standard_normal(size=input_shape, dtype=np.float32),) class aten_index_list(torch.nn.Module): def __init__(self, idxs): - super(TestIndexNone.aten_index_list, self).__init__() + super().__init__() self.idxs = idxs def forward(self, x):