diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index 4a4b33abea..ab9629b0db 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -15,11 +16,10 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_torch,
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
 
-import tensorrt as trt
-
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -125,10 +125,9 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
-        rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
+        rhs_val = to_torch(rhs_val, dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
-        lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype))
-
+        lhs_val = to_torch(lhs_val, dtype=rhs_dtype)
     lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
 
diff --git a/tests/py/dynamo/conversion/test_binary_ops_aten.py b/tests/py/dynamo/conversion/test_binary_ops_aten.py
index 79c0d9430a..ac8cf4b00b 100644
--- a/tests/py/dynamo/conversion/test_binary_ops_aten.py
+++ b/tests/py/dynamo/conversion/test_binary_ops_aten.py
@@ -228,6 +228,28 @@ def forward(self, x, y):
         ]
         self.run_test_with_dynamic_shape(Op(), input_specs)
 
+    @parameterized.expand(
+        [
+            (f"bf16_{op[0].__name__}_one_constant", op[0])
+            for op in elementwise_ops
+            if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
+        ]
+    )
+    def test_elementwise_ops_bf16(self, _, orig_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant = torch.randn(1)
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                x = self.orig_op(x, self.constant)
+                return self.orig_op(x, -2)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
+        self.run_test(m, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
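
Context for the change above (a minimal illustrative sketch, not part of the patch): the scalar-promotion path in convert_binary_elementwise moves from np.array to to_torch because numpy has no native bfloat16 dtype, while torch does, so a Python scalar constant can only keep the bf16 dtype of its TRT-tensor operand if it is materialized as a torch tensor.

import numpy as np
import torch

# torch can represent a scalar at bfloat16 directly, so promoting a Python
# constant to the dtype of a bf16 operand is straightforward:
t = torch.tensor([-2], dtype=torch.bfloat16)
print(t.dtype)  # torch.bfloat16

# numpy ships no bfloat16 dtype, so the old np.array(...) promotion had no
# way to express a bf16 scalar:
try:
    np.dtype("bfloat16")
except TypeError as err:
    print(err)  # data type 'bfloat16' not understood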