Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/phi/ops/yaml/python_api_info.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,16 @@
x : [input]
y : [values]

- op : i0
name : [paddle.i0,paddle.Tensor.i0]
args_alias :
use_default_mapping : True

- op : i0e
name : [paddle.i0e,paddle.Tensor.i0e]
args_alias :
use_default_mapping : True

- op : i1
name : [paddle.i1,paddle.Tensor.i1]
args_alias :
Expand Down
79 changes: 79 additions & 0 deletions python/paddle/_paddle_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,85 @@ def softplus(
""",
)

add_doc_and_signature(
"i0",
"""
The function used to calculate modified bessel function of order 0.

Equation:
.. math::

I_0(x) = \\sum^{\\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2}

Args:
x (Tensor): The input tensor, it's data type should be float32, float64,
uint8, int8, int16, int32, int64.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
- out (Tensor), A Tensor. the value of the modified bessel function of order 0 at x
(integer types are autocasted into float32).

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
>>> paddle.i0(x)
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089])
""",
"""
def i0(
x: Tensor,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor
""",
)

add_doc_and_signature(
"i0e",
"""
The function used to calculate exponentially scaled modified Bessel function of order 0.

Equation:
.. math::

I_0(x) = \\sum^{\\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2} \\
I_{0e}(x) = e^{-|x|}I_0(x)

Args:
x (Tensor): The input tensor, it's data type should be float32, float64,
uint8, int8, int16, int32, int64.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
- out (Tensor), A Tensor. the value of the exponentially scaled modified Bessel function of order 0 at x
(integer types are autocasted into float32).

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
>>> print(paddle.i0e(x))
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.99999994, 0.46575963, 0.30850831, 0.24300036, 0.20700191])
""",
"""
def i0e(
x: Tensor,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor
""",
)

add_doc_and_signature(
"isclose",
r"""
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.

from .tensor.compat_softmax import softmax
from .tensor.math import i1, i1e, logsumexp
from .tensor.math import i0, i0e, i1, i1e, logsumexp
from .tensor.ops import expm1

__all__ = [
"i0",
"i0e",
"i1",
"i1e",
"logsumexp",
Expand Down
91 changes: 2 additions & 89 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
fmax,
fmin,
heaviside,
i0,
i0e,
i1,
i1e,
isfinite,
Expand Down Expand Up @@ -6190,50 +6192,6 @@ def vander(
return res


def i0(x: Tensor, name: str | None = None) -> Tensor:
r"""
The function used to calculate modified bessel function of order 0.

Equation:
.. math::

I_0(x) = \sum^{\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2}

Args:
x (Tensor): The input tensor, it's data type should be float32, float64,
uint8, int8, int16, int32, int64.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
- out (Tensor), A Tensor. the value of the modified bessel function of order 0 at x
(integer types are autocasted into float32).

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
>>> paddle.i0(x)
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089])
"""
if in_dynamic_or_pir_mode():
return _C_ops.i0(x)
else:
check_variable_and_dtype(
x,
"x",
["float32", "float64", "uint8", "int8", "int16", "int32", "int64"],
"i0",
)

helper = LayerHelper("i0", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='i0', inputs={'x': x}, outputs={'out': out})
return out


@inplace_apis_in_dygraph_only
def i0_(x: Tensor, name: str | None = None) -> Tensor:
r"""
Expand All @@ -6245,51 +6203,6 @@ def i0_(x: Tensor, name: str | None = None) -> Tensor:
return _C_ops.i0_(x)


def i0e(x: Tensor, name: str | None = None) -> Tensor:
r"""
The function used to calculate exponentially scaled modified Bessel function of order 0.

Equation:
.. math::

I_0(x) = \sum^{\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2} \\
I_{0e}(x) = e^{-|x|}I_0(x)

Args:
x (Tensor): The input tensor, it's data type should be float32, float64,
uint8, int8, int16, int32, int64.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
- out (Tensor), A Tensor. the value of the exponentially scaled modified Bessel function of order 0 at x
(integer types are autocasted into float32).

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
>>> print(paddle.i0e(x))
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.99999994, 0.46575963, 0.30850831, 0.24300036, 0.20700191])
"""
if in_dynamic_or_pir_mode():
return _C_ops.i0e(x)
else:
check_variable_and_dtype(
x,
"x",
["float32", "float64", "uint8", "int8", "int16", "int32", "int64"],
"i0e",
)

helper = LayerHelper("i0e", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='i0e', inputs={'x': x}, outputs={'out': out})
return out


def polygamma(x: Tensor, n: int, name: str | None = None) -> Tensor:
r"""
Calculates the polygamma of the given input tensor, element-wise.
Expand Down
67 changes: 67 additions & 0 deletions test/legacy_test/test_i0_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,72 @@ def test_check_grad(self):
self.check_grad(['x'], 'out')


class TestI0API_Compatibility(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]

def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.out = output_i0(self.x)
self.place = get_places()

def test_dygraph_Compatibility(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
paddle_dygraph_out = []
# Position args (args)
out1 = paddle.i0(x)
paddle_dygraph_out.append(out1)
# Key words args (kwargs) for paddle
out2 = paddle.i0(x=x)
paddle_dygraph_out.append(out2)
# Key words args for torch
out3 = paddle.i0(input=x)
paddle_dygraph_out.append(out3)

# Tensor method kwargs
out4 = x.i0()
paddle_dygraph_out.append(out4)
# Test out
out5 = paddle.empty([])
paddle.i0(x, out=out5)
paddle_dygraph_out.append(out5)
# scipy reference out
ref_out = output_i0(self.x)
# Check
for out in paddle_dygraph_out:
np.testing.assert_allclose(out.numpy(), ref_out, rtol=1e-5)
paddle.enable_static()

def test_static_Compatibility(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
# Position args (args)
out1 = paddle.i0(x)
# Key words args (kwargs) for paddle
out2 = paddle.i0(x=x)
# Key words args for torch
out3 = paddle.i0(input=x)
# Tensor method args
out4 = x.i0()

exe = paddle.static.Executor(place)
fetches = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[out1, out2, out3, out4],
)
for out in fetches:
np.testing.assert_allclose(out, self.out, rtol=1e-5)
paddle.disable_static()

for place in self.place:
run(place)


if __name__ == "__main__":
unittest.main()
67 changes: 67 additions & 0 deletions test/legacy_test/test_i0e_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,72 @@ def test_check_grad(self):
self.check_grad(['x'], 'out')


class TestI0eAPI_Compatibility(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]

def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.out = output_i0e(self.x)
self.place = get_places()

def test_dygraph_Compatibility(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
paddle_dygraph_out = []
# Position args (args)
out1 = paddle.i0e(x)
paddle_dygraph_out.append(out1)
# Key words args (kwargs) for paddle
out2 = paddle.i0e(x=x)
paddle_dygraph_out.append(out2)
# Key words args for torch
out3 = paddle.i0e(input=x)
paddle_dygraph_out.append(out3)

# Tensor method kwargs
out4 = x.i0e()
paddle_dygraph_out.append(out4)
# Test out
out5 = paddle.empty([])
paddle.i0e(x, out=out5)
paddle_dygraph_out.append(out5)
# scipy reference out
ref_out = output_i0e(self.x)
# Check
for out in paddle_dygraph_out:
np.testing.assert_allclose(out.numpy(), ref_out, rtol=1e-5)
paddle.enable_static()

def test_static_Compatibility(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
# Position args (args)
out1 = paddle.i0e(x)
# Key words args (kwargs) for paddle
out2 = paddle.i0e(x=x)
# Key words args for torch
out3 = paddle.i0e(input=x)
# Tensor method args
out4 = x.i0e()

exe = paddle.static.Executor(place)
fetches = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[out1, out2, out3, out4],
)
for out in fetches:
np.testing.assert_allclose(out, self.out, rtol=1e-5)
paddle.disable_static()

for place in self.place:
run(place)


if __name__ == "__main__":
unittest.main()
Loading