Labels: bug (Something isn't working)
Bug Description
The matmul implementation in the dynamo converter casts its inputs when use_explicit_typing=True. This casting is unnecessary, and it is not the correct way to handle dtype mismatches for PyTorch's torch.matmul.
Relevant code: https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/impl/matmul.py#L51-L60 (blame at commit dad195b92ecedd21ca8785f9643472c56b96b502: https://github.com/pytorch/TensorRT/blame/dad195b92ecedd21ca8785f9643472c56b96b502/py/torch_tensorrt/dynamo/conversion/impl/matmul.py#L51)
torch.matmul itself only accepts inputs of the same dtype, so casting inside the converter masks the underlying issue rather than addressing it. We need a proper fix for the root cause. In eager PyTorch, mixing dtypes fails:
```python
>>> import torch
>>> a = torch.randn(3, dtype=torch.float32)
>>> b = torch.randn(3, dtype=torch.float16)
>>> torch.matmul(a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Half
```
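The same constraint holds for the other shapes torch.matmul dispatches on, not only the vector-vector (dot) case. The CPU-only sketch below checks this in eager mode; matmul_accepts is a hypothetical helper written for illustration. It only shows that eager PyTorch rejects mixed float32/float16 inputs and accepts them once the caller aligns the dtypes; it is not a proposal for how the Torch-TensorRT converter should behave.

```python
import torch


def matmul_accepts(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper: True if eager torch.matmul runs on (a, b), False if it raises."""
    try:
        torch.matmul(a, b)
        return True
    except RuntimeError:
        return False


# vector @ vector (dot), matrix @ matrix (mm), batched @ batched (bmm)
shapes = [((3,), (3,)), ((2, 3), (3, 4)), ((5, 2, 3), (5, 3, 4))]

for shape_a, shape_b in shapes:
    a = torch.randn(shape_a, dtype=torch.float32)
    b = torch.randn(shape_b, dtype=torch.float16)
    # Mixed float32/float16 inputs: eager PyTorch raises instead of promoting.
    print(shape_a, shape_b, "mixed dtypes accepted:", matmul_accepts(a, b))
    # Aligning the dtypes explicitly on the caller side (here: upcasting b) is accepted.
    print(shape_a, shape_b, "aligned dtypes accepted:", matmul_accepts(a, b.to(a.dtype)))
```

Because eager torch.matmul never promotes mismatched inputs, a graph that reaches the matmul converter with two different dtypes suggests that something upstream produced the mismatch.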
To Reproduce
Steps to reproduce the behavior:
```python
import torch
import torch_tensorrt as torchtrt
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)
model = pipe.unet
model_torch = model.eval()
model_torch = model_torch.half().to('cuda')

input_tensors = [
    torch.randn((1, 4, 64, 64), dtype=torch.float16).cuda(),
    torch.Tensor([1.0]).cuda(),
    torch.randn((1, 1, 768), dtype=torch.float16).cuda(),
]

model_trt = torchtrt.compile(
    model_torch,
    inputs=input_tensors,
    ir="dynamo",
    cache_built_engines=False,
    reuse_cached_engines=False,
    min_block_size=1,
    debug=False,
    use_python_runtime=True,
    use_explicit_typing=True,
    truncate_long_and_double=True,
)
```
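Note that the inputs in this script do not all share one dtype: the first and third tensors are float16, while torch.Tensor([1.0]) is created with PyTorch's default dtype (float32). The CPU-only sketch below just confirms this; the labels in the comments are our reading of the UNet's positional inputs, and this issue does not establish whether this float32 tensor is what ends up feeding a mixed-dtype matmul.

```python
import torch

# CPU stand-ins for the repro inputs; the issue moves them to CUDA, which does not change dtypes.
input_tensors = [
    torch.randn((1, 4, 64, 64), dtype=torch.float16),  # presumably the latent sample
    torch.Tensor([1.0]),  # presumably the timestep; torch.Tensor uses the default dtype
    torch.randn((1, 1, 768), dtype=torch.float16),  # presumably the text encoder hidden states
]

# Collect the distinct dtypes across all example inputs.
dtypes = sorted({str(t.dtype) for t in input_tensors})
print(dtypes)  # ['torch.float16', 'torch.float32']
```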
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context