Skip to content

fix: Remove type casting in matmul and add scalar tensor conversion #3713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

keehyuna
Copy link
Collaborator

Description

Compiling Stable Diffusion UNet model with use_explicit_typing=True caused dtype mismatches in matmul operations due to incorrect dtype propagation from elementwise operations, specifically aten.div.Tensor operations.

Problem Chain:

  1. Elementwise Operation: aten.div.Tensor with mixed dtypes (float16 tensor / float32 scalar)
  2. Incorrect Dtype Propagation: TensorRT incorrectly promoted output to float32 instead of maintaining float16
  3. MatMul Impact: The incorrectly typed output propagated to subsequent matmul operations
  4. Type Mismatch: MatMul received inputs with incompatible dtypes, causing compilation failures
torch_tensorrt.dynamo.conversion._TRTInterpreter[INFO]Converted node down_blocks.0.attentions.0.transformer_blocks.0.attn1/div_2 [aten.div.Tensor] (Inputs: (_reshape_copy_12: (1, 8, 4096, 4096)@torch.float16, _frozen_param8: ()@torch.float32) | Outputs: (div_2: (1, 8, 4096, 4096)@torch.float16))

in debugging log, output of (float16 tensor/float32 scalar) is float16. But dtype of output TRTTensor is float32.

This precision promotion is different from torch. Proposed fix is to convert torch scalar to value then scalar operands adopt the tensor's dtype
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py#L134-L139

  • TensorRT: float16 tensor / float32 scalar → float32 output
  • PyTorch: float16 tensor / float32 scalar → float16 output
>>> y = torch.tensor([ 0.3810,  1.2774, -0.2972, -0.3719,  0.4637], dtype=torch.float16)
>>> value = torch.tensor(0.4, dtype=torch.float)
>>> y/value
tensor([ 0.9526,  3.1934, -0.7427, -0.9297,  1.1592], dtype=torch.float16)
>>> x = 0.5
>>> k = torch.tensor([ 0.3810,  1.2774, -0.2972, -0.3719,  0.4637], dtype=torch.float32)
>>> (y/k).dtype
torch.float32

Fixes #3712

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@keehyuna keehyuna self-assigned this Jul 24, 2025
@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jul 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
Projects
None yet
Development

Successfully merging this pull request may close these issues.

🐛 [Bug] Incorrect Type Casting in MatMul Operator
2 participants