bf16 support for elementwise operations #3462

Open · wants to merge 1 commit into base: main
py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (9 changes: 4 additions & 5 deletions)
@@ -3,6 +3,7 @@
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -15,11 +16,10 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_torch,
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
 
-import tensorrt as trt
-
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -125,10 +125,9 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
-        rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
+        rhs_val = to_torch(rhs_val, dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
-        lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype))
-
+        lhs_val = to_torch(lhs_val, dtype=rhs_dtype)
     lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
 
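Why this switch enables bf16: NumPy has no native bfloat16 dtype, so routing a Python scalar through np.array (the old path) cannot produce a bf16 constant to match a bf16 tensor operand, whereas a torch tensor can carry the dtype through to get_trt_tensor. A minimal standalone sketch of that difference (illustrative only, not part of the diff; to_torch is the converter-utils helper imported above):

import numpy as np
import torch

scalar = -2

# torch can materialize the scalar directly in bfloat16 ...
t = torch.tensor([scalar], dtype=torch.bfloat16)
print(t.dtype)  # torch.bfloat16

# ... while NumPy rejects the dtype outright, since it has no bfloat16:
try:
    np.array([scalar], dtype=np.dtype("bfloat16"))
except TypeError as err:
    print(err)  # data type 'bfloat16' not understood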
tests/py/dynamo/conversion/test_binary_ops_aten.py (22 changes: 22 additions & 0 deletions)
@@ -228,6 +228,28 @@ def forward(self, x, y):
         ]
         self.run_test_with_dynamic_shape(Op(), input_specs)
 
+    @parameterized.expand(
+        [
+            (f"bf16_{op[0].__name__}_one_constant", op[0])
+            for op in elementwise_ops
+            if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
+        ]
+    )
+    def test_elementwise_ops_bf16(self, _, orig_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant = torch.randn(1)
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                x = self.orig_op(x, self.constant)
+                return self.orig_op(x, -2)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
+        self.run_test(m, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
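For reference, a self-contained sketch of how the @parameterized.expand comprehension above expands into named bf16 cases while skipping the two excluded overloads; the elementwise_ops stub here is a hypothetical stand-in for the real op table defined earlier in the test file:

from types import SimpleNamespace

# Hypothetical stand-in for the module-level elementwise_ops table.
elementwise_ops = [
    (SimpleNamespace(__name__="add.Tensor"),),
    (SimpleNamespace(__name__="mul.Tensor"),),
    (SimpleNamespace(__name__="pow.Tensor_Tensor"),),  # filtered out below
    (SimpleNamespace(__name__="fmod.Tensor"),),        # filtered out below
]

cases = [
    (f"bf16_{op[0].__name__}_one_constant", op[0])
    for op in elementwise_ops
    if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
]
print([name for name, _ in cases])
# ['bf16_add.Tensor_one_constant', 'bf16_mul.Tensor_one_constant']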