
Commit 721a0ce

bf16 support
1 parent ca59597 commit 721a0ce

2 files changed, +26 −7 lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (+3 −5)

@@ -15,12 +15,11 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_torch,
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
-
 import tensorrt as trt
 
-
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
 ) -> Callable[[Any, Any], Any]:
@@ -125,10 +124,9 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
-        rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
+        rhs_val = to_torch(rhs_val, dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
-        lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype))
-
+        lhs_val = to_torch(lhs_val, dtype=rhs_dtype)
     lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
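
The likely motivation for swapping the np.array scalar-promotion path for to_torch: NumPy has no native bfloat16 dtype, so the old path cannot build a scalar constant that matches a bf16 lhs/rhs tensor, while torch supports bfloat16 directly. A minimal standalone sketch of that difference (an illustration only, not the library code; the to_torch helper in the hunk above is the commit's actual fix):

import numpy as np
import torch

scalar = -2  # Python scalar that must be promoted to the other operand's dtype

# torch can represent bfloat16 directly, so the scalar constant can
# match a bf16 tensor operand:
bf16_const = torch.tensor([scalar], dtype=torch.bfloat16)

# NumPy has no bfloat16 dtype, so the old np.array(...) promotion path
# has nothing to map bf16 onto and fails:
try:
    np.dtype("bfloat16")
except TypeError as err:
    print(f"numpy cannot represent bf16: {err}")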

tests/py/dynamo/conversion/test_binary_ops_aten.py (+23 −2)

@@ -2,12 +2,11 @@
 
 import torch
 import torch.nn as nn
+from .harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
-from .harness import DispatchTestCase
-
 NEED_TEST_BOTH_CONSTANTS_CASE = True
 
 elementwise_ops = [
@@ -228,6 +227,28 @@ def forward(self, x, y):
         ]
         self.run_test_with_dynamic_shape(Op(), input_specs)
 
+    @parameterized.expand(
+        [
+            (f"bf16_{op[0].__name__}_one_constant", op[0])
+            for op in elementwise_ops
+            if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
+        ]
+    )
+    def test_elementwise_ops_bf16(self, _, orig_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant = torch.randn(1)
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                x = self.orig_op(x, self.constant)
+                return self.orig_op(x, -2)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
+        self.run_test(m, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
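
To exercise just the new bf16 cases locally, a pytest invocation along these lines should work (assuming the repo's usual test setup and a bf16-capable GPU; the -k filter matches the bf16_ prefix in the parameterized test names):

python -m pytest tests/py/dynamo/conversion/test_binary_ops_aten.py -k bf16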
