Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/phi/ops/yaml/python_api_info.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@
args_mapper :
func : ArgSumMapper

- op : tanh
name : [paddle.tanh, paddle.Tensor.tanh, paddle.nn.functional.tanh]
args_alias:
use_default_mapping : True

- op : exp
name : [paddle.exp, paddle.Tensor.exp]
args_alias:
Expand Down
60 changes: 55 additions & 5 deletions python/paddle/_paddle_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2314,6 +2314,47 @@ def dot(
""",
)

add_doc_and_signature(
"tanh",
r"""

Tanh Activation Operator.

.. math::
out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}

.. note::
Alias Support:
1. The parameter name ``input`` can be used as an alias for ``x``.

Args:
x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64,
float16, uint8, int8, int16, int32, int64. Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output tensor. Default: None.

Returns:
Output of Tanh operator, a Tensor with same data type and shape as input
(integer types are autocasted into float32).

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
>>> out = paddle.tanh(x)
>>> out
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[-0.37994900, -0.19737528, 0.09966799, 0.29131261])
""",
"""
def tanh(
x: Tensor, *, out: Tensor | None = None, name: str | None = None,
) -> Tensor
""",
)

add_doc_and_signature(
"exp",
"""
Expand All @@ -2331,7 +2372,7 @@ def dot(
x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128.
Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output tensor.
out (Tensor|None, optional): The output tensor. Default: None.

Returns:
Tensor. Output of Exp operator, a Tensor with shape same as input.
Expand Down Expand Up @@ -2371,7 +2412,7 @@ def exp(
x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128.
Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output tensor.
out (Tensor|None, optional): The output tensor. Default: None.

Returns:
Tensor. Output of Expm1 operator, a Tensor with shape same as input.
Expand Down Expand Up @@ -2494,10 +2535,16 @@ def diagonal(
out.shape = [4]
out.data = [1., -1., 3., 1.]

.. note::
Alias Support:
1. The parameter name ``input`` can be used as an alias for ``x``.

Args:
x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, int32, int64, float32, float64, float16, complex64 or complex128.
Alias: ``input``.
decimals(int): Rounded decimal place (default: 0).
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output tensor. Default: None.

Returns:
Tensor. Output of Round operator, a Tensor with shape same as input.
Expand Down Expand Up @@ -2529,12 +2576,15 @@ def round(

out = |x|

.. note::
Alias Support:
1. The parameter name ``input`` can be used as an alias for ``x``.

Args:
x (Tensor): The input Tensor with data type int32, int64, float16, float32, float64, complex64 and complex128.
Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Keyword args:
out (Tensor|None, optional): The output tensor.
out (Tensor|None, optional): The output tensor. Default: None.

Returns:
Tensor.A Tensor with the same data type and shape as :math:`x`.
Expand Down
55 changes: 1 addition & 54 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
sign,
sin,
sum,
tanh,
)
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import VarDesc, dygraph_utils
Expand Down Expand Up @@ -4454,60 +4455,6 @@ def prod(
return out


def tanh(x: Tensor, name: str | None = None) -> Tensor:
r"""
Tanh Activation Operator.

.. math::
out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}

Args:
x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64,
float16, uint8, int8, int16, int32, int64.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Output of Tanh operator, a Tensor with same data type and shape as input
(integer types are autocasted into float32).

Examples:

.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
>>> out = paddle.tanh(x)
>>> out
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[-0.37994900, -0.19737528, 0.09966799, 0.29131261])
"""
if in_dynamic_or_pir_mode():
return _C_ops.tanh(x)
else:
check_variable_and_dtype(
x,
'x',
[
'uint16',
'float16',
'float32',
'float64',
'uint8',
'int8',
'int16',
'int32',
'int64',
],
'tanh',
)
check_type(x, 'x', (Variable), 'tanh')
helper = LayerHelper('tanh', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out})
return out


@inplace_apis_in_dygraph_only
def tanh_(x: Tensor, name: str | None = None) -> Tensor:
r"""
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6167,6 +6167,7 @@ class TestActivationAPI_Compatibility(unittest.TestCase):
("paddle.exp", np.exp, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.expm1", np.expm1, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.round", np.round, {'min_val': -5.0, 'max_val': 5.0}),
("paddle.tanh", np.tanh, {'min_val': -1.0, 'max_val': 1.0}),
]

def setUp(self):
Expand Down
12 changes: 6 additions & 6 deletions test/standalone_executor/test_standalone_custom_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def build_program():
matmul_out = data @ weight
bias = paddle.ones([1024, 2048], dtype='float32', name='bias')
add_out = paddle.add(matmul_out, bias, name='add_out')
# add_out -> [sub] -> sub_out -> [tanh] -> tanh_out
# add_out -> [sub] -> sub_out -> [silu] -> silu_out
sub_out = paddle.subtract(add_out, data, name='sub_out')
tanh_out = paddle.tanh(sub_out, name='tanh_out')
silu_out = paddle.nn.functional.silu(sub_out, name='silu_out')
bias_1 = paddle.add(bias, sub_out, name='bias_1')
out_before = paddle.tanh(bias_1, name='out_before')
out_last = paddle.subtract(tanh_out, data, name='out_last')
out_before = paddle.nn.functional.silu(bias_1, name='out_before')
out_last = paddle.subtract(silu_out, data, name='out_last')
out_last2 = out_last @ weight

out = paddle.add(out_before, out_last2, name='out')
Expand All @@ -64,9 +64,9 @@ class TestManualEvent(unittest.TestCase):
| | | |
| elementwise_sub(s1) |
| | | |
| tanh(s1) elementwise_add(s1)
| silu(s1) elementwise_add(s1)
| | |
elementwise_sub(s1) tanh(s1)
elementwise_sub(s1) silu(s1)
| |
matmul_v2(s1) |
| | ---split prog----
Expand Down
4 changes: 2 additions & 2 deletions test/standalone_executor/test_standalone_custom_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class TestCustomStream(unittest.TestCase):
| | | |
| elementwise_sub(cpu) |
| | | |
| tanh(cpu) elementwise_add(s2)
| silu(cpu) elementwise_add(s2)
| | |
elementwise_sub(s1) tanh(s2)
elementwise_sub(s1) silu(s2)
| |
elementwise_add(s2)
|
Expand Down
8 changes: 4 additions & 4 deletions test/standalone_executor/test_standalone_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ def build_program():
bias = paddle.ones([4, 64], dtype='float32', name='bias')
add_out = paddle.add(matmul_out, bias, name='add_out')

# add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [tanh] -> tanh_out
# add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [silu] -> silu_out
with paddle.static.device_guard('cpu'):
sub_out = paddle.subtract(add_out, data, name='sub_out')
tanh_out = paddle.tanh(sub_out, name='tanh_out')
silu_out = paddle.nn.functional.silu(sub_out, name='silu_out')

with paddle.static.device_guard('gpu'):
bias_1 = paddle.add(bias, sub_out, name='bias_1')
out_before = paddle.tanh(bias_1, name='out_before')
out_last = paddle.subtract(tanh_out, data, name='out_last')
out_before = paddle.nn.functional.silu(bias_1, name='out_before')
out_last = paddle.subtract(silu_out, data, name='out_last')

out = paddle.add(out_before, out_last, name='out')
mean = paddle.mean(out, name='mean_out')
Expand Down
Loading