diff --git a/paddle/phi/ops/yaml/python_api_info.yaml b/paddle/phi/ops/yaml/python_api_info.yaml index 93b415d867396d..217e68aae76908 100644 --- a/paddle/phi/ops/yaml/python_api_info.yaml +++ b/paddle/phi/ops/yaml/python_api_info.yaml @@ -12,10 +12,12 @@ name : [paddle.matmul,paddle.Tensor.matmul] args_alias : use_default_mapping : True + - op : multiply name : [paddle.multiply,paddle.Tensor.multiply] args_alias : use_default_mapping : True + - op : log2 name : [paddle.log2,paddle.Tensor.log2] args_alias : @@ -88,6 +90,7 @@ name : [paddle.all,paddle.Tensor.all] args_alias: use_default_mapping : True + - op : bmm name : [paddle.bmm, paddle.Tensor.bmm] args_alias: @@ -194,6 +197,28 @@ args_mapper : func : ArgSumMapper +- op : exp + name : [paddle.exp, paddle.Tensor.exp] + args_alias: + use_default_mapping : True + +- op : expm1 + name : [paddle.expm1, paddle.Tensor.expm1] + args_alias: + use_default_mapping : True + +- op : diagonal + name : [paddle.diagonal, paddle.Tensor.diagonal] + args_alias: + x : [input] + axis1 : [dim1] + axis2 : [dim2] + +- op : round + name : [paddle.round, paddle.Tensor.round] + args_alias: + use_default_mapping : True + - op : abs name : [paddle.abs, paddle.Tensor.abs] args_alias: diff --git a/python/paddle/_paddle_docs.py b/python/paddle/_paddle_docs.py index 1bf9a0f5cdc81c..c440d3f4c35ea4 100644 --- a/python/paddle/_paddle_docs.py +++ b/python/paddle/_paddle_docs.py @@ -2314,6 +2314,212 @@ def dot( """, ) +add_doc_and_signature( + "exp", + """ + + Computes exp of x element-wise with a natural number `e` as the base. + + .. math:: + out = e^x + + .. note:: + Alias Support: + 1. The parameter name ``input`` can be used as an alias for ``x``. + + Args: + x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128. + Alias: ``input``. + name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + out (Tensor|None, optional): The output tensor. + + Returns: + Tensor. Output of Exp operator, a Tensor with shape same as input. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3]) + >>> out = paddle.exp(x) + >>> print(out) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.67032003, 0.81873077, 1.10517097, 1.34985888]) +""", + """ +def exp( + x: Tensor, *, out: Tensor | None = None, name: str | None = None +) -> Tensor +""", +) + +add_doc_and_signature( + "expm1", + """ + + Expm1 Operator. Computes expm1 of x element-wise with a natural number :math:`e` as the base. + + .. math:: + out = e^x - 1 + + .. note:: + Alias Support: + 1. The parameter name ``input`` can be used as an alias for ``x``. + + Args: + x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128. + Alias: ``input``. + name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + out (Tensor|None, optional): The output tensor. + + Returns: + Tensor. Output of Expm1 operator, a Tensor with shape same as input. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3]) + >>> out = paddle.expm1(x) + >>> print(out) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.32967997, -0.18126924, 0.10517092, 0.34985882]) +""", + """ +def expm1( + x: Tensor, *, out: Tensor | None = None, name: str | None = None +) -> Tensor +""", +) + +add_doc_and_signature( + "diagonal", + """ + + Computes the diagonals of the input tensor x. + + If ``x`` is 2D, returns the diagonal. + If ``x`` has larger dimensions, diagonals be taken from the 2D planes specified by axis1 and axis2. + By default, the 2D planes formed by the first and second axis of the input tensor x. + + The argument ``offset`` determines where diagonals are taken from input tensor x: + + - If offset = 0, it is the main diagonal. + - If offset > 0, it is above the main diagonal. + - If offset < 0, it is below the main diagonal. + + .. note:: + Alias Support: + 1. The parameter name ``input`` can be used as an alias for ``x``. + 2. The parameter name ``dim1`` can be used as an alias for ``axis1``. + 3. The parameter name ``dim2`` can be used as an alias for ``axis2``. + + Args: + x (Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be bool, int32, + int64, bfloat16, float16, float32, float64. Alias: ``input``. + offset (int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals). + axis1 (int, optional): The first axis with respect to take diagonal. Default: 0. Alias: ``dim1``. + axis2 (int, optional): The second axis with respect to take diagonal. Default: 1. Alias: ``dim2``. + name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: a partial view of input tensor in specify two dimensions, the output data type is the same as input data type. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.seed(2023) + >>> x = paddle.rand([2, 2, 3],'float32') + >>> print(x) + Tensor(shape=[2, 2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[0.86583614, 0.52014720, 0.25960937], + [0.90525323, 0.42400089, 0.40641287]], + [[0.97020894, 0.74437362, 0.51785129], + [0.73292869, 0.97786582, 0.04315904]]]) + + >>> out1 = paddle.diagonal(x) + >>> print(out1) + Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.86583614, 0.73292869], + [0.52014720, 0.97786582], + [0.25960937, 0.04315904]]) + + >>> out2 = paddle.diagonal(x, offset=0, axis1=2, axis2=1) + >>> print(out2) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.86583614, 0.42400089], + [0.97020894, 0.97786582]]) + + >>> out3 = paddle.diagonal(x, offset=1, axis1=0, axis2=1) + >>> print(out3) + Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.90525323], + [0.42400089], + [0.40641287]]) + + >>> out4 = paddle.diagonal(x, offset=0, axis1=1, axis2=2) + >>> print(out4) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.86583614, 0.42400089], + [0.97020894, 0.97786582]]) +""", + """ +def diagonal( + x: Tensor, + offset: int = 0, + axis1: int = 0, + axis2: int = 1, + name: str | None = None, +) -> Tensor +""", +) + +add_doc_and_signature( + "round", + """ + + Round the values in the input to the nearest integer value. + + .. code-block:: text + + input: + x.shape = [4] + x.data = [1.2, -0.9, 3.4, 0.9] + + output: + out.shape = [4] + out.data = [1., -1., 3., 1.] + + Args: + x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, int32, int64, float32, float64, float16, complex64 or complex128. + decimals(int): Rounded decimal place (default: 0). + name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. Output of Round operator, a Tensor with shape same as input. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([-0.5, -0.2, 0.6, 1.5]) + >>> out = paddle.round(x) + >>> print(out) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0., -0., 1., 2.]) +""", + """ +def round( + x: Tensor, decimals = 0, *, out: Tensor | None = None, name: str | None = None, +) -> Tensor +""", +) + add_doc_and_signature( "abs", """ @@ -2350,7 +2556,6 @@ def abs( ) -> Tensor """, ) - # lubingxin # chenhuangrun diff --git a/python/paddle/special.py b/python/paddle/special.py index dc0d1661aacf21..68d420644eeacb 100644 --- a/python/paddle/special.py +++ b/python/paddle/special.py @@ -14,8 +14,10 @@ from .tensor.compat_softmax import softmax from .tensor.math import logsumexp +from .tensor.ops import expm1 __all__ = [ "logsumexp", "softmax", + "expm1", ] diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 559f5f62ee5f00..d5796745215e53 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -21,7 +21,7 @@ import paddle from paddle import _C_ops -from paddle._C_ops import bmm, dot, matmul # noqa: F401 +from paddle._C_ops import bmm, diagonal, dot, matmul # noqa: F401 from paddle.base.libpaddle import DataType from paddle.common_ops_import import VarDesc from paddle.tensor.math import broadcast_shape @@ -5774,131 +5774,3 @@ def cholesky_inverse( else: A = x @ x.T return paddle.linalg.inv(A) - - -def diagonal( - x: Tensor, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - name: str | None = None, -) -> Tensor: - """ - Computes the diagonals of the input tensor x. - - If ``x`` is 2D, returns the diagonal. - If ``x`` has larger dimensions, diagonals be taken from the 2D planes specified by axis1 and axis2. - By default, the 2D planes formed by the first and second axis of the input tensor x. - - The argument ``offset`` determines where diagonals are taken from input tensor x: - - - If offset = 0, it is the main diagonal. - - If offset > 0, it is above the main diagonal. - - If offset < 0, it is below the main diagonal. - - Args: - x (Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be bool, int32, - int64, bfloat16, float16, float32, float64. - offset (int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals). - axis1 (int, optional): The first axis with respect to take diagonal. Default: 0. - axis2 (int, optional): The second axis with respect to take diagonal. Default: 1. - name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor: a partial view of input tensor in specify two dimensions, the output data type is the same as input data type. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> paddle.seed(2023) - >>> x = paddle.rand([2, 2, 3],'float32') - >>> print(x) - Tensor(shape=[2, 2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, - [[[0.86583614, 0.52014720, 0.25960937], - [0.90525323, 0.42400089, 0.40641287]], - [[0.97020894, 0.74437362, 0.51785129], - [0.73292869, 0.97786582, 0.04315904]]]) - - >>> out1 = paddle.diagonal(x) - >>> print(out1) - Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.86583614, 0.73292869], - [0.52014720, 0.97786582], - [0.25960937, 0.04315904]]) - - >>> out2 = paddle.diagonal(x, offset=0, axis1=2, axis2=1) - >>> print(out2) - Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.86583614, 0.42400089], - [0.97020894, 0.97786582]]) - - >>> out3 = paddle.diagonal(x, offset=1, axis1=0, axis2=1) - >>> print(out3) - Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.90525323], - [0.42400089], - [0.40641287]]) - - >>> out4 = paddle.diagonal(x, offset=0, axis1=1, axis2=2) - >>> print(out4) - Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.86583614, 0.42400089], - [0.97020894, 0.97786582]]) - - """ - if in_dynamic_or_pir_mode(): - return _C_ops.diagonal(x, offset, axis1, axis2) - else: - - def __check_input(x, offset, axis1, axis2): - check_dtype( - x.dtype, - 'Input', - [ - 'bool', - 'int32', - 'int64', - 'float16', - 'uint16', - 'float32', - 'float64', - ], - 'diagonal', - ) - - input_shape = list(x.shape) - assert len(input_shape) >= 2, ( - "The x must be at least 2-dimensional, " - f"But received Input x's dimensional: {len(input_shape)}.\n" - ) - - axis1_ = axis1 if axis1 >= 0 else len(input_shape) + axis1 - axis2_ = axis2 if axis2 >= 0 else len(input_shape) + axis2 - - assert axis1_ < len(input_shape), ( - f"The argument axis1 is out of range (expected to be in range of [{-(len(input_shape))}, {len(input_shape) - 1}], but got {axis1}).\n" - ) - - assert axis2_ < len(input_shape), ( - f"The argument axis2 is out of range (expected to be in range of [{-(len(input_shape))}, {len(input_shape) - 1}], but got {axis2}).\n" - ) - - assert axis1_ != axis2_, ( - "axis1 and axis2 cannot be the same axis." - f"But received axis1 = {axis1}, axis2 = {axis2}\n" - ) - - __check_input(x, offset, axis1, axis2) - helper = LayerHelper('diagonal', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type='diagonal', - inputs={'Input': [x]}, - attrs={'offset': offset, 'axis1': axis1, 'axis2': axis2}, - outputs={'Out': [out]}, - ) - - return out diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py index 7ba6b546eea523..14df627ed0c8da 100644 --- a/python/paddle/tensor/ops.py +++ b/python/paddle/tensor/ops.py @@ -19,7 +19,10 @@ abs, ceil, cos, + exp, + expm1, floor, + round, rsqrt, sigmoid, sin, @@ -457,106 +460,6 @@ def cosh(x: Tensor, name: str | None = None) -> Tensor: return out -def exp(x: Tensor, name: str | None = None) -> Tensor: - """ - - Computes exp of x element-wise with a natural number `e` as the base. - - .. math:: - out = e^x - - Args: - x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128. - name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. Output of Exp operator, a Tensor with shape same as input. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3]) - >>> out = paddle.exp(x) - >>> print(out) - Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, - [0.67032003, 0.81873077, 1.10517097, 1.34985888]) - """ - if in_dynamic_or_pir_mode(): - return _C_ops.exp(x) - else: - check_variable_and_dtype( - x, - 'x', - [ - 'int32', - 'int64', - 'uint16', - 'float16', - 'float32', - 'float64', - 'complex64', - 'complex128', - ], - 'exp', - ) - helper = LayerHelper('exp', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op(type='exp', inputs={"X": x}, outputs={"Out": out}) - return out - - -def expm1(x: Tensor, name: str | None = None) -> Tensor: - """ - - Expm1 Operator. Computes expm1 of x element-wise with a natural number :math:`e` as the base. - - .. math:: - out = e^x - 1 - - Args: - x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128. - name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. Output of Expm1 operator, a Tensor with shape same as input. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3]) - >>> out = paddle.expm1(x) - >>> print(out) - Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, - [-0.32967997, -0.18126924, 0.10517092, 0.34985882]) - """ - if in_dynamic_or_pir_mode(): - return _C_ops.expm1(x) - else: - check_variable_and_dtype( - x, - 'x', - [ - 'float16', - 'uint16', - 'float32', - 'float64', - 'int32', - 'int64', - 'complex64', - 'complex128', - ], - 'expm1', - ) - helper = LayerHelper('expm1', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op(type='expm1', inputs={"X": x}, outputs={"Out": out}) - return out - - def reciprocal(x: Tensor, name: str | None = None) -> Tensor: """ @@ -612,69 +515,6 @@ def reciprocal(x: Tensor, name: str | None = None) -> Tensor: return out -def round(x: Tensor, decimals: int = 0, name: str | None = None) -> Tensor: - """ - - Round the values in the input to the nearest integer value. - - .. code-block:: text - - input: - x.shape = [4] - x.data = [1.2, -0.9, 3.4, 0.9] - - output: - out.shape = [4] - out.data = [1., -1., 3., 1.] - - Args: - x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, int32, int64, float32, float64, float16, complex64 or complex128. - decimals(int): Rounded decimal place (default: 0). - name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. Output of Round operator, a Tensor with shape same as input. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> x = paddle.to_tensor([-0.5, -0.2, 0.6, 1.5]) - >>> out = paddle.round(x) - >>> print(out) - Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, - [-0., -0., 1., 2.]) - """ - if in_dynamic_or_pir_mode(): - return _C_ops.round(x, decimals) - else: - check_variable_and_dtype( - x, - 'x', - [ - 'float16', - 'uint16', - 'int32', - 'int64', - 'float32', - 'float64', - 'complex64', - 'complex128', - ], - 'round', - ) - helper = LayerHelper('round', **locals()) - attrs = { - 'decimals': int(decimals), - } - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type='round', inputs={"X": x}, outputs={"Out": out}, attrs=attrs - ) - return out - - @inplace_apis_in_dygraph_only def round_(x, decimals=0, name=None): r""" diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index a72407c157555f..e85d7358fca7c5 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -6164,6 +6164,9 @@ class TestActivationAPI_Compatibility(unittest.TestCase): ACTIVATION_CONFIGS = [ ("paddle.abs", np.abs, {'min_val': -1.0, 'max_val': 1.0}), ("paddle.log2", np.log2, {'min_val': 0.0, 'max_val': 8.0}), + ("paddle.exp", np.exp, {'min_val': -1.0, 'max_val': 1.0}), + ("paddle.expm1", np.expm1, {'min_val': -1.0, 'max_val': 1.0}), + ("paddle.round", np.round, {'min_val': -5.0, 'max_val': 5.0}), ] def setUp(self): diff --git a/test/legacy_test/test_diagonal_op.py b/test/legacy_test/test_diagonal_op.py index 68d8f683a3d6b2..d61dc520efaff7 100644 --- a/test/legacy_test/test_diagonal_op.py +++ b/test/legacy_test/test_diagonal_op.py @@ -234,5 +234,108 @@ def init_config(self): ).copy() +class TestDiagonalAPI_Compatibility(unittest.TestCase): + def setUp(self): + np.random.seed(123) + paddle.enable_static() + self.shape = [5, 6, 7] + self.dtype = 'float32' + self.init_data() + + def init_data(self): + self.np_input = np.random.rand(*self.shape).astype(self.dtype) + + def test_dygraph_Compatibility(self): + paddle.disable_static() + x = paddle.to_tensor(self.np_input) + paddle_dygraph_out = [] + # Position args (args) + out1 = paddle.diagonal(x) + paddle_dygraph_out.append(out1) + # Key words args for paddle + out2 = paddle.diagonal(x=x, offset=1, axis1=0, axis2=2) + paddle_dygraph_out.append(out2) + # Key words args for torch + out3 = paddle.diagonal(input=x, offset=-1, dim1=1, dim2=2) + paddle_dygraph_out.append(out3) + # Mixed args - paddle parameters prioritized + out4 = paddle.diagonal(x, offset=0, axis1=1, axis2=2) + paddle_dygraph_out.append(out4) + # Mixed args - torch parameters prioritized + out5 = paddle.diagonal(input=x, offset=0, dim1=1, dim2=2) + paddle_dygraph_out.append(out5) + # Tensor method args + out6 = x.diagonal() + paddle_dygraph_out.append(out6) + # Tensor method kwargs + out7 = x.diagonal(offset=2, dim1=0, dim2=1) + paddle_dygraph_out.append(out7) + + ref_out1 = np.diagonal(self.np_input) + ref_out2 = np.diagonal(self.np_input, offset=1, axis1=0, axis2=2) + ref_out3 = np.diagonal(self.np_input, offset=-1, axis1=1, axis2=2) + ref_out4 = np.diagonal(self.np_input, offset=0, axis1=1, axis2=2) + ref_out5 = np.diagonal(self.np_input, offset=0, axis1=1, axis2=2) + ref_out6 = np.diagonal(self.np_input) + ref_out7 = np.diagonal(self.np_input, offset=2, axis1=0, axis2=1) + ref_outs = [ + ref_out1, + ref_out2, + ref_out3, + ref_out4, + ref_out5, + ref_out6, + ref_out7, + ] + for out, ref_out in zip(paddle_dygraph_out, ref_outs): + np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-6) + paddle.enable_static() + + def test_static_Compatibility(self): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.base.program_guard(main, startup): + x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype) + # Position args (args) + out1 = paddle.diagonal(x) + # Key words args for paddle + out2 = paddle.diagonal(x=x, offset=1, axis1=0, axis2=2) + # Key words args for torch + out3 = paddle.diagonal(input=x, offset=-1, dim1=1, dim2=2) + # Mixed args - paddle parameters prioritized + out4 = paddle.diagonal(x, offset=0, axis1=1, axis2=2) + # Mixed args - torch parameters prioritized + out5 = paddle.diagonal(input=x, offset=0, dim1=1, dim2=2) + # Tensor method args + out6 = x.diagonal() + # Tensor method kwargs + out7 = x.diagonal(offset=2, dim1=0, dim2=1) + + exe = paddle.base.Executor(paddle.CPUPlace()) + fetches = exe.run( + main, + feed={"x": self.np_input}, + fetch_list=[out1, out2, out3, out4, out5, out6, out7], + ) + ref_out1 = np.diagonal(self.np_input) + ref_out2 = np.diagonal(self.np_input, offset=1, axis1=0, axis2=2) + ref_out3 = np.diagonal(self.np_input, offset=-1, axis1=1, axis2=2) + ref_out4 = np.diagonal(self.np_input, offset=0, axis1=1, axis2=2) + ref_out5 = np.diagonal(self.np_input, offset=0, axis1=1, axis2=2) + ref_out6 = np.diagonal(self.np_input) + ref_out7 = np.diagonal(self.np_input, offset=2, axis1=0, axis2=1) + ref_outs = [ + ref_out1, + ref_out2, + ref_out3, + ref_out4, + ref_out5, + ref_out6, + ref_out7, + ] + for out, ref_out in zip(fetches, ref_outs): + np.testing.assert_allclose(out, ref_out, rtol=1e-6) + + if __name__ == '__main__': unittest.main()