Skip to content

Commit cce2771

Browse files
authored
[API Compatibility] Support tensor dtype compare using is (#76155)
1 parent 8e2251f commit cce2771

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

paddle/fluid/pybind/eager_utils.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,9 +1205,14 @@ PyObject* ToPyObject(const phi::DenseTensor* value) {
12051205
}
12061206

12071207
PyObject* ToPyObject(const phi::DataType& dtype) {
1208-
auto obj = ::pybind11::cast(dtype);
1209-
obj.inc_ref();
1210-
return obj.ptr();
1208+
static const std::vector<std::string> dtype_names = {
1209+
"UNDEFINED", "BOOL", "UINT8", "INT8", "UINT16",
1210+
"INT16", "UINT32", "INT32", "UINT64", "INT64",
1211+
"FLOAT32", "FLOAT64", "COMPLEX64", "COMPLEX128", "PSTRING",
1212+
"FLOAT16", "BFLOAT16", "FLOAT8_E4M3FN", "FLOAT8_E5M2",
1213+
};
1214+
return PyObject_GetAttrString(reinterpret_cast<PyObject*>(g_data_type_pytype),
1215+
dtype_names[static_cast<int>(dtype)].c_str());
12111216
}
12121217

12131218
PyObject* ToPyObject(const std::vector<phi::DataType>& dtypes) {

test/legacy_test/test_eager_tensor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,20 @@ def test_print_tensor_dtype(self):
13461346

13471347
self.assertEqual(a_str, expected)
13481348

1349+
def test_tensor_dtype_compare(self):
1350+
a = paddle.randn([2], dtype="float32")
1351+
b = paddle.randn([2], dtype="float32")
1352+
c = paddle.randn([2], dtype="float64")
1353+
1354+
self.assertTrue(a.dtype == paddle.float32)
1355+
self.assertTrue(a.dtype == b.dtype)
1356+
self.assertTrue(a.dtype != paddle.float64)
1357+
self.assertTrue(a.dtype != c.dtype)
1358+
self.assertTrue(a.dtype is paddle.float32)
1359+
self.assertTrue(a.dtype is b.dtype)
1360+
self.assertTrue(a.dtype is not paddle.float64)
1361+
self.assertTrue(a.dtype is not c.dtype)
1362+
13491363
def test___cuda_array_interface__(self):
13501364
"""test Tensor.__cuda_array_interface__"""
13511365
with dygraph_guard():

0 commit comments

Comments
 (0)