
float8_e4m3fn precision overflow #2219

@jiqing-feng

Description


🐛 Describe the bug

import torch

# Load the attached clamped tensor onto the accelerator (device 0),
# round-trip it through float8_e4m3fn, and print the resulting max.
value = torch.load("clamped_value.pt").to(0)
print(value.to(torch.float8_e4m3fn).to(torch.float16).max())

Attachment: clamped_value.zip (contains clamped_value.pt)

XPU output:

tensor(nan, device='xpu:0', dtype=torch.float16)

A100 output:

tensor(288., device='cuda:0', dtype=torch.float16)
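For reference, a self-contained variant of the repro that does not need the attachment, sketched under the assumption that clamped_value.pt holds float16 values already clamped into the float8_e4m3fn range (finite max 448). It round-trips a synthetic ramp through float8_e4m3fn on each available backend and counts NaNs; since every input is within the representable range, no NaNs are expected on a correct backend.

import torch

# float8_e4m3fn has a finite maximum of 448 and no infinities.
finfo = torch.finfo(torch.float8_e4m3fn)

# Synthetic stand-in for the attached tensor: a float16 ramp that spans
# the full e4m3fn range but never exceeds it (assumption, not the real data).
x = torch.linspace(-finfo.max, finfo.max, steps=4096, dtype=torch.float16)

for device in ("cpu", "cuda", "xpu"):
    backend = getattr(torch, device, None) if device != "cpu" else None
    if device != "cpu" and (backend is None or not backend.is_available()):
        continue
    v = x.to(device)
    # Cast down to float8_e4m3fn and back up to float16.
    roundtrip = v.to(torch.float8_e4m3fn).to(torch.float16)
    print(device,
          "max:", roundtrip.max().item(),
          "NaNs:", roundtrip.isnan().sum().item())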

Versions

[pip3] torch==2.10.0.dev20251026+xpu
[pip3] torchao==0.15.0.dev20251027+xpu
[pip3] torchaudio==2.10.0.dev20251026+xpu
[pip3] torchdata==0.12.0.dev20250220
[pip3] torchvision==0.25.0.dev20251026+xpu
[pip3] triton==3.5.0
