Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,12 +1738,15 @@ def uniform_(
return _C_ops.uniform_inplace_(x, min, max, seed, 0, 0, 1.0)


@param_one_alias(["shape", "size"])
def randint(
low: int = 0,
high: int | None = None,
shape: ShapeLike = [1],
dtype: DTypeLike | None = None,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
"""
Returns a Tensor filled with random integers from a discrete uniform
Expand All @@ -1760,12 +1763,14 @@ def randint(
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. Default is [1].
Alias: ``size``.
dtype (str|np.dtype|paddle.dtype|None, optional): The data type of the
output tensor. Supported data types: int32, int64. If ``dtype``
is None, the data type is int64. Default is None.
name (str|None, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): Optional output tensor. If provided, the result will be stored in this tensor.

Returns:
Tensor, A Tensor filled with random integers from a discrete uniform
Expand Down Expand Up @@ -1846,15 +1851,15 @@ def randint(
if in_dynamic_mode():
shape = paddle.utils.convert_shape_to_list(shape)
return _C_ops.randint(
low, high, shape, dtype, _current_expected_place()
low, high, shape, dtype, _current_expected_place(), out=out
)
elif in_pir_mode():
check_shape(shape, 'randint')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
if paddle.utils._contain_var(shape):
shape = paddle.utils.get_int_tensor_list(shape)
return _C_ops.randint(
low, high, shape, dtype, _current_expected_place()
low, high, shape, dtype, _current_expected_place(), out=out
)
else:
check_shape(shape, 'randint')
Expand All @@ -1872,7 +1877,8 @@ def randint(
)

helper = LayerHelper("randint", **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs
)
Expand Down
125 changes: 125 additions & 0 deletions test/legacy_test/test_randint_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,130 @@ def test_static(self):
paddle.enable_static()


class TestRandintAliasAndOut(unittest.TestCase):
def test_alias_and_out(self):
paddle.disable_static()

# Test size alias (param_one_alias decorator: shape -> size)
result_1 = paddle.randint(5, size=[3, 4])
result_2 = paddle.randint(5, size=paddle.to_tensor([3, 4]))
self.assertEqual(result_1.shape, [3, 4])
self.assertEqual(result_2.shape, [3, 4])

# Test out parameter with int32 dtype
result_3 = paddle.randint(high=5, shape=[3, 4], dtype='int32')
out = paddle.zeros([3, 4], dtype='int32')
result_4 = paddle.randint(high=5, shape=[3, 4], dtype='int32', out=out)
self.assertTrue(paddle.equal_all(result_4, out))
self.assertEqual(result_4.dtype, paddle.int32)

# Test out parameter with int64 dtype
out_int64 = paddle.zeros([2, 5], dtype='int64')
result_5 = paddle.randint(
high=10, shape=[2, 5], dtype='int64', out=out_int64
)
self.assertTrue(paddle.equal_all(result_5, out_int64))
self.assertEqual(result_5.dtype, paddle.int64)

# Test WITHOUT out parameter (out=None, triggers 'if out is None' branch)
result_6 = paddle.randint(high=5, shape=[3, 4], dtype='int32')
self.assertEqual(result_6.shape, [3, 4])
self.assertEqual(result_6.dtype, paddle.int32)

result_7 = paddle.randint(high=5, shape=[2, 3], dtype='int64')
self.assertEqual(result_7.shape, [2, 3])
self.assertEqual(result_7.dtype, paddle.int64)

paddle.enable_static()

def test_out_static_mode(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# In static mode (PIR), out parameter is not supported (as shown by warning)
# Test creates new tensor (out=None), triggering 'if out is None' branch
result1 = paddle.randint(high=5, shape=[3, 4], dtype='int32')
self.assertEqual(result1.shape, (3, 4))

result2 = paddle.randint(high=10, shape=[2, 5], dtype='int64')
self.assertEqual(result2.shape, (2, 5))

def test_size_alias_static_mode(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# Test size parameter as an alias for shape in static mode
result = paddle.randint(high=5, size=[3, 4], dtype='int32')
self.assertEqual(result.shape, (3, 4))


class TestRandintOldStaticMode(unittest.TestCase):
"""Test randint in old static graph mode (non-PIR mode).

This test specifically covers the else branch in randint:
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)

This branch is only executed when:
1. Not in dynamic mode (in_dynamic_mode() returns False)
2. Not in PIR mode (in_pir_mode() returns False)
"""

def test_out_none_old_static_mode(self):
"""Test that 'if out is None' branch is covered in old static mode."""
from paddle.pir_utils import OldIrGuard

with OldIrGuard():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()

with paddle.static.program_guard(main_program, startup_program):
# This should go through the else branch (old static mode)
# and trigger 'if out is None: out = helper.create_variable_for_type_inference(dtype=dtype)'
result1 = paddle.randint(high=5, shape=[3, 4], dtype='int32')
result2 = paddle.randint(high=10, shape=[2, 5], dtype='int64')

# Verify shapes are correct
self.assertEqual(result1.shape, (3, 4))
self.assertEqual(result2.shape, (2, 5))

# Execute the program to verify it works
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
outs = exe.run(main_program, fetch_list=[result1, result2])

# Verify the outputs
self.assertEqual(outs[0].shape, (3, 4))
self.assertEqual(outs[1].shape, (2, 5))
# Verify values are in expected range
self.assertTrue(np.all(outs[0] >= 0) and np.all(outs[0] < 5))
self.assertTrue(np.all(outs[1] >= 0) and np.all(outs[1] < 10))

def test_size_alias_old_static_mode(self):
"""Test size alias in old static mode."""
from paddle.pir_utils import OldIrGuard

with OldIrGuard():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()

with paddle.static.program_guard(main_program, startup_program):
# Test using 'size' parameter alias
result = paddle.randint(high=5, size=[4, 5], dtype='int32')
self.assertEqual(result.shape, (4, 5))

# Execute the program
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
outs = exe.run(main_program, fetch_list=[result])

self.assertEqual(outs[0].shape, (4, 5))
self.assertTrue(np.all(outs[0] >= 0) and np.all(outs[0] < 5))


if __name__ == "__main__":
unittest.main()
Loading