diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 026b051b337104..237e5de74b4022 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -57,6 +57,7 @@
     range,
     resize_,
     set_,
+    split_with_sizes,
     to_tensor,
     tril,
     tril_,
@@ -949,6 +950,7 @@
     'greater',
     'clamp',
     'clamp_',
+    'split_with_sizes',
 ]
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index dc879efedf16a2..c012918e590e2a 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -2556,13 +2556,15 @@ def triu_(

 @overload
 def meshgrid(
-    args: Sequence[paddle.Tensor], name: str | None = None
+    args: Sequence[paddle.Tensor],
+    name: str | None = None,
+    indexing: str | None = None,
 ) -> list[paddle.Tensor]: ...


 @overload
 def meshgrid(
-    *args: paddle.Tensor, name: str | None = None
+    *args: paddle.Tensor, name: str | None = None, indexing: str | None = None
 ) -> list[paddle.Tensor]: ...


@@ -2577,7 +2579,9 @@ def meshgrid(*args, **kwargs):
         **kwargs (optional): Currently, only accept name in **kwargs
             The default value is None. Normally there is no need for user to set this property.
             For more information, please refer to :ref:`api_guide_Name`.
-
+        indexing (str, optional): The indexing mode, either 'xy' or 'ij'. Defaults to 'ij'. If 'xy' is selected, the first dimension
+            corresponds to the cardinality of the second input and the second dimension corresponds to the cardinality of the first
+            input. If 'ij' is selected, the dimensions are in the same order as the cardinality of the inputs.
     Returns:
         Tensor: k tensors. The shape of each tensor is (N1, N2, ..., Nk)

@@ -2597,13 +2601,26 @@ def meshgrid(*args, **kwargs):
             [100, 200]

     """
+    name = kwargs.get("name", None)
+    indexing = kwargs.pop("indexing", None)
+    if indexing is None:
+        indexing = "ij"

     if len(args) == 1 and isinstance(args[0], (list, tuple)):
         args = args[0]
+
+    if indexing not in ("ij", "xy"):
+        raise ValueError(
+            f"meshgrid: indexing must be 'ij' or 'xy', but got {indexing}"
+        )
+
+    swap_xy = indexing == "xy" and len(args) >= 2
+    if swap_xy:
+        args = (args[1], args[0], *args[2:])
+
     if in_dynamic_or_pir_mode():
-        return _C_ops.meshgrid(list(args))
+        out = _C_ops.meshgrid(list(args))
     else:
-        name = kwargs.get("name", None)
         helper = LayerHelper('meshgrid', **locals())

         if not isinstance(args, (list, tuple)):
@@ -2637,7 +2654,60 @@ def meshgrid(*args, **kwargs):
             type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out}
         )

-    return out
+    if swap_xy:
+        out[0], out[1] = out[1], out[0]
+    return out
+
+
+def split_with_sizes(
+    self: paddle.Tensor, split_sizes: list[int], dim: int = 0
+) -> list[paddle.Tensor]:
+    """
+    Split the input tensor into multiple sub-tensors according to the given split sizes.
+
+    Args:
+        self (Tensor): The input tensor to be split.
+        split_sizes (list[int]): A list of non-negative integers specifying
+            the sizes of each split along dimension ``dim``. The sum of all
+            elements in this list must equal the size of ``self`` along ``dim``.
+        dim (int, optional): The dimension along which to split the tensor.
+            Defaults to 0.
+
+    Returns:
+        list[Tensor]: A list of sub-tensors resulting from splitting ``self``
+            along the specified dimension.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
+            >>> # Split along the first dimension into two parts, of sizes 1 and 2
+            >>> splits = paddle.Tensor.split_with_sizes(x, [1, 2], dim=0)
+            >>> print([s.shape for s in splits])
+            [[1, 4], [2, 4]]
+    """
+    for size in split_sizes:
+        if size < 0:
+            raise ValueError(
+                f"split_with_sizes expects split_sizes to contain only non-negative entries, but got {size}"
+            )
+
+    total = sum(split_sizes)
+    if total != self.shape[dim]:
+        raise ValueError(
+            f"split_with_sizes expects split_sizes to sum to {self.shape[dim]} (the size of dimension {dim}), but they sum to {total}"
+        )
+
+    outs = []
+    start = 0
+    for size in split_sizes:
+        end = start + size
+        out = paddle.slice(self, axes=[dim], starts=[start], ends=[end])
+        outs.append(out)
+        start = end
+
+    return outs


 def diag_embed(
diff --git a/test/legacy_test/test_meshgrid_op.py b/test/legacy_test/test_meshgrid_op.py
index 7442b114a348a6..4ff1627ccd9978 100644
--- a/test/legacy_test/test_meshgrid_op.py
+++ b/test/legacy_test/test_meshgrid_op.py
@@ -311,6 +311,73 @@ def test_api_with_dygraph(self):
         np.testing.assert_array_equal(res_4.shape, [100, 200])


+class TestMeshgridOpIndexing(unittest.TestCase):
+    def setUp(self):
+        self.input_3 = np.random.randint(0, 100, [100]).astype('int32')
+        self.input_4 = np.random.randint(0, 100, [200]).astype('int32')
+
+    def test_api_with_dygraph_indexing_xy(self):
+        np_res_3, np_res_4 = np.meshgrid(
+            self.input_3, self.input_4, indexing='xy'
+        )
+
+        with base.dygraph.guard():
+            tensor_3 = paddle.to_tensor(self.input_3)
+            tensor_4 = paddle.to_tensor(self.input_4)
+            res_3, res_4 = paddle.tensor.meshgrid(
+                tensor_3, tensor_4, indexing='xy'
+            )
+
+        np.testing.assert_array_equal(res_3.shape, np_res_3.shape)
+        np.testing.assert_array_equal(res_4.shape, np_res_4.shape)
+        np.testing.assert_array_equal(res_3.numpy(), np_res_3)
+        np.testing.assert_array_equal(res_4.numpy(), np_res_4)
+
+    def test_api_with_dygraph_indexing_ij(self):
+        np_res_3, np_res_4 = np.meshgrid(
+            self.input_3, self.input_4, indexing='ij'
+        )
+
+        with base.dygraph.guard():
+            tensor_3 = paddle.to_tensor(self.input_3)
+            tensor_4 = paddle.to_tensor(self.input_4)
+            res_3, res_4 = paddle.tensor.meshgrid(
+                tensor_3, tensor_4, indexing='ij'
+            )
+
+        np.testing.assert_array_equal(res_3.shape, np_res_3.shape)
+        np.testing.assert_array_equal(res_4.shape, np_res_4.shape)
+        np.testing.assert_array_equal(res_3.numpy(), np_res_3)
+        np.testing.assert_array_equal(res_4.numpy(), np_res_4)
+
+    def test_indexing_default(self):
+        np_res_3, np_res_4 = np.meshgrid(
+            self.input_3, self.input_4, indexing='ij'
+        )
+
+        with base.dygraph.guard():
+            tensor_3 = paddle.to_tensor(self.input_3)
+            tensor_4 = paddle.to_tensor(self.input_4)
+            res_3, res_4 = paddle.tensor.meshgrid(tensor_3, tensor_4)
+            res_3_n, res_4_n = paddle.tensor.meshgrid(
+                tensor_3, tensor_4, indexing=None
+            )
+            np.testing.assert_array_equal(res_3.numpy(), np_res_3)
+            np.testing.assert_array_equal(res_4.numpy(), np_res_4)
+            np.testing.assert_array_equal(res_3_n.numpy(), np_res_3)
+            np.testing.assert_array_equal(res_4_n.numpy(), np_res_4)
+
+    def test_indexing_invalid_value(self):
+        with base.dygraph.guard():
+            tensor_3 = paddle.to_tensor(self.input_3)
+            tensor_4 = paddle.to_tensor(self.input_4)
+            invalid_indexing = "ab"
+            with self.assertRaises(ValueError) as cm:
+                res_3, res_4 = paddle.tensor.meshgrid(
+                    tensor_3, tensor_4, indexing=invalid_indexing
+                )
+
+
 class TestMeshgridOp7(unittest.TestCase):
     def test_api_with_dygraph_list_input(self):
         input_3 = np.random.randint(
diff --git a/test/legacy_test/test_split_with_sizes_api.py b/test/legacy_test/test_split_with_sizes_api.py
new file mode 100644
index 00000000000000..8d03ab00c3287f
--- /dev/null
+++ b/test/legacy_test/test_split_with_sizes_api.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class TestSplitWithSizes(unittest.TestCase):
+    def setUp(self):
+        self.x = paddle.arange(12).reshape([3, 4])
+        self.split_sizes = [1, 2]
+        self.dim = 0
+
+    def test_basic_functionality(self):
+        splits = paddle.Tensor.split_with_sizes(
+            self.x, self.split_sizes, dim=self.dim
+        )
+
+        self.assertEqual(len(splits), len(self.split_sizes))
+
+        expected_shapes = [[1, 4], [2, 4]]
+        for s, shape in zip(splits, expected_shapes):
+            self.assertListEqual(list(s.shape), shape)
+
+        np_x = self.x.numpy()
+        start = 0
+        for i, size in enumerate(self.split_sizes):
+            np_ref = np_x[start : start + size, :]
+            np.testing.assert_array_equal(splits[i].numpy(), np_ref)
+            start += size
+
+    def test_ValueError_raises(self):
+        invalid_split_sizes = [1, -2]
+        with self.assertRaises(ValueError) as cm:
+            paddle.Tensor.split_with_sizes(
+                self.x, invalid_split_sizes, dim=self.dim
+            )
+
+        invalid_split_sizes = [1, 1]
+        with self.assertRaises(ValueError) as cm:
+            paddle.Tensor.split_with_sizes(
+                self.x, invalid_split_sizes, dim=self.dim
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
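A minimal usage sketch of the two APIs this patch adds, assuming the patch above is applied. It calls paddle.tensor.meshgrid and paddle.Tensor.split_with_sizes exactly as the tests do; the 'indexing' semantics follow NumPy's meshgrid convention, and the split sizes must sum to the length of the split dimension.

import paddle

# meshgrid with the new `indexing` argument: 'ij' keeps the input order,
# 'xy' swaps the first two output dimensions, as in NumPy.
x = paddle.to_tensor([1, 2, 3])   # length 3
y = paddle.to_tensor([10, 20])    # length 2
gx_ij, gy_ij = paddle.tensor.meshgrid(x, y, indexing='ij')  # both shaped [3, 2]
gx_xy, gy_xy = paddle.tensor.meshgrid(x, y, indexing='xy')  # both shaped [2, 3]

# split_with_sizes: the sizes must be non-negative and sum to the size of `dim`.
t = paddle.arange(12).reshape([3, 4])
parts = paddle.Tensor.split_with_sizes(t, [1, 2], dim=0)
print([p.shape for p in parts])  # [[1, 4], [2, 4]]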