diff --git a/test/legacy_test/test_stack_extension_api.py b/test/legacy_test/test_stack_extension_api.py index 462f6f82523524..ae40ff91bb9fb3 100644 --- a/test/legacy_test/test_stack_extension_api.py +++ b/test/legacy_test/test_stack_extension_api.py @@ -89,6 +89,10 @@ def _test_static_api( names: list, ): """Test `static`, convert `Tensor` to `numpy array` before feed into graph""" + # convert grad value to bool if dtype is bool + grad_value = 123.0 if dtypes[0] != 'bool' else True + if dtypes[0] == 'bfloat16': + grad_value = paddle.to_tensor(grad_value, dtype=dtypes[0]).numpy() paddle.enable_static() for device, place in PLACES: @@ -130,8 +134,6 @@ def _test_static_api( exe = paddle.static.Executor(place) res, *res_grad = exe.run(feed=feed, fetch_list=fetch_list) - # convert grad value to bool if dtype is bool - grad_value = 123.0 if dtypes[0] != 'bool' else True np.testing.assert_allclose( res_grad[0], np.ones(x[0].shape) * grad_value )