
🐛 [Bug] Incorrect Type Casting in MatMul Operator #3712

@keehyuna


Bug Description

The matmul converter casts its inputs when use_explicit_typing=True. This casting is unnecessary, and it is not the correct way to handle dtype mismatches for PyTorch's torch.matmul.
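The following plain-PyTorch sketch is only an analogy, not the converter's actual code (which builds TensorRT layers). It contrasts a cast-to-common-dtype approach with a strict check. The helper names `matmul_with_cast` and `matmul_strict` are hypothetical, and promoting via `torch.promote_types` is an assumption about what such a cast might look like.

```python
import torch


def matmul_with_cast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Masking approach (hypothetical): promote both operands to a common dtype
    # so any mismatch "works", hiding where the mismatch came from.
    common = torch.promote_types(a.dtype, b.dtype)
    return torch.matmul(a.to(common), b.to(common))


def matmul_strict(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Strict approach (hypothetical): mirror torch.matmul's own contract and
    # report the offending dtypes so the upstream producer can be tracked down.
    if a.dtype != b.dtype:
        raise TypeError(f"matmul operands differ in dtype: {a.dtype} vs {b.dtype}")
    return torch.matmul(a, b)


a = torch.randn(3, dtype=torch.float32)
b = torch.randn(3, dtype=torch.float16)
print(matmul_with_cast(a, b).dtype)  # torch.float32: the fp16 operand was silently upcast
```

The cast version returns a result without ever saying why the dtypes differed, whereas the strict version fails at the point of the mismatch, which is what makes the root cause findable.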

Relevant code: https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/impl/matmul.py#L51-L60 (blame at dad195b: https://github.com/pytorch/TensorRT/blame/dad195b92ecedd21ca8785f9643472c56b96b502/py/torch_tensorrt/dynamo/conversion/impl/matmul.py#L51)
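torch.matmul only accepts operands of the same dtype, so mismatched inputs reaching the converter suggest that the dtypes diverged somewhere earlier. Casting inside the converter masks that underlying issue rather than addressing it; we need the right fix for the root cause. In eager PyTorch, a mismatch is rejected outright: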

>>> import torch
>>> a = torch.randn(3, dtype=torch.float32)
>>> b = torch.randn(3, dtype=torch.float16)
>>> torch.matmul(a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Half
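Because eager PyTorch rejects mismatched operands, every torch.matmul call in a model that runs successfully in eager mode sees matching dtypes. One possible way to collect those reference dtypes, for comparison with what the converter receives, is a `torch.overrides.TorchFunctionMode` that logs operand dtypes. This is a minimal sketch; the class name `MatmulDtypeRecorder` is hypothetical, and it only intercepts `torch.matmul` and the Tensor matmul methods, so ops such as `bmm`, `baddbmm`, or `F.linear` would need to be added to cover them.

```python
import torch
from torch.overrides import TorchFunctionMode

# Functions this sketch treats as "matmul"; bmm, baddbmm, F.linear, etc.
# are NOT intercepted and would have to be added to catch them too.
_MATMUL_FUNCS = (torch.matmul, torch.Tensor.matmul, torch.Tensor.__matmul__)


class MatmulDtypeRecorder(TorchFunctionMode):
    """Records the operand dtypes of every intercepted matmul call in eager mode."""

    def __init__(self):
        super().__init__()
        self.records = []  # (lhs dtype, rhs dtype) tuples, in call order

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _MATMUL_FUNCS and len(args) >= 2:
            lhs, rhs = args[0], args[1]
            self.records.append((lhs.dtype, rhs.dtype))
        # The mode is disabled while this handler runs, so this call does not recurse.
        return func(*args, **kwargs)
```

Running the eager model inside `with MatmulDtypeRecorder() as rec:` and inspecting `rec.records` would give the dtype pairs PyTorch actually uses; any matmul where the converter sees a different pair points to the place where the dtypes diverged.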

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt as torchtrt
from diffusers import DiffusionPipeline

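# Load the fp16 Stable Diffusion v1.4 pipeline and compile only its UNet
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)

model = pipe.unet
model_torch = model.eval()
model_torch = model_torch.half().to('cuda')
# UNet inputs: latent sample (fp16), timestep (torch.Tensor, float32 under the
# default dtype), and encoder hidden states (fp16)
input_tensors = [
    torch.randn((1, 4, 64, 64), dtype=torch.float16).cuda(),
    torch.Tensor([1.0]).cuda(),
    torch.randn((1, 1, 768), dtype=torch.float16).cuda(),
]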

model_trt = torchtrt.compile(
    model_torch,
    inputs=input_tensors,
    ir="dynamo",
    cache_built_engines=False,
    reuse_cached_engines=False,
    min_block_size=1,
    debug=False,
    use_python_runtime=True,
    use_explicit_typing=True,  # setting under which the matmul converter casts its inputs
    truncate_long_and_double=True,
)

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Metadata


Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
