Commit aa552ec

nvfp4 tensor: refactor weight-only vs dynamic quant

Summary: Refactors `NVFP4Tensor` to use `act_quant_kwargs`, following the design of the recently added `Float8Tensor`. Note that we chose not to use `_choose_quant_func_and_quantize_tensor`, as we do not support any activation types other than nvfp4; this can be relaxed in the future if needed. This is still not the final API; we might need to make more tweaks before bringing it out of prototype.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 61cbdf1
ghstack-comment-id: 3197771544
Pull-Request: #2790

1 parent: 82eec4f
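
At a glance, the refactor moves the weight-only vs dynamic decision from an `mm_config` enum stored on the tensor to an optional `act_quant_kwargs` attached to the weight. A minimal sketch of the new call sites, pieced together from the diffs below (the tensor shape and device are illustrative, not from this commit):

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
)

weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

# Weight-only quant: act_quant_kwargs stays None, so the linear dispatch
# dequantizes the weight and falls back to a high-precision F.linear.
w_weight_only = NVFP4Tensor.to_nvfp4(weight, is_swizzled_scales=True)

# Dynamic quant: attach act_quant_kwargs, so the dispatch quantizes the
# incoming activation to nvfp4 with these settings before the fp4 mm.
w_dynamic = NVFP4Tensor.to_nvfp4(
    weight,
    is_swizzled_scales=True,
    act_quant_kwargs=QuantizeTensorToNVFP4Kwargs(),
)
```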

File tree: 4 files changed, +79 −38 lines

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -916,8 +916,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
-    assert ctx[3] == True
+    assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
+    assert ctx[2] == True
 
     # Test deserialization
     inner_tensors = {}
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 9 additions & 4 deletions

```diff
@@ -15,6 +15,9 @@
 from torchao.prototype.mx_formats.inference_workflow import (
     NVFP4MMConfig,
 )
+from torchao.prototype.mx_formats.nvfp4_tensor import (
+    QuantizeTensorToNVFP4Kwargs,
+)
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -304,8 +307,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
-    assert ctx[3] == True
+    assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
+    assert ctx[2] == True
 
     # Test deserialization
     inner_tensors = {}
@@ -491,19 +494,21 @@ def test_nvfp4_matmul_with_amax(
 
     a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
     b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+    act_quant_kwargs = None
+    if mm_config == NVFP4MMConfig.DYNAMIC:
+        act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
     A_nvfp4 = NVFP4Tensor.to_nvfp4(
         A,
         per_tensor_scale=a_scale,
-        mm_config=mm_config,
         is_swizzled_scales=True,
         use_triton_kernel=use_triton_kernel,
     )
     B_nvfp4 = NVFP4Tensor.to_nvfp4(
         B,
         per_tensor_scale=b_scale,
-        mm_config=mm_config,
         is_swizzled_scales=True,
         use_triton_kernel=use_triton_kernel,
+        act_quant_kwargs=act_quant_kwargs,
     )
 
     func = torch.compile(F.linear, fullgraph=True) if compile else F.linear
```

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 10 additions & 2 deletions

```diff
@@ -19,7 +19,11 @@
     _validate_gemm_kernel_choice,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
+from torchao.prototype.mx_formats.nvfp4_tensor import (
+    NVFP4MMConfig,
+    NVFP4Tensor,
+    QuantizeTensorToNVFP4Kwargs,
+)
 from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -199,11 +203,15 @@ def _nvfp4_inference_linear_transform(
             "Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
         )
 
+    act_quant_kwargs = None
+    if config.mm_config == NVFP4MMConfig.DYNAMIC:
+        act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
+
     quantized_weight = NVFP4Tensor.to_nvfp4(
         weight,
-        mm_config=config.mm_config,
         is_swizzled_scales=True,
         use_triton_kernel=False,  # Always use traditional construction for weights
+        act_quant_kwargs=act_quant_kwargs,
     )
     # Set triton preference after construction
     quantized_weight.use_triton_kernel = config.use_triton_kernel
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 58 additions & 30 deletions

```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import sys
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Optional
 
@@ -24,6 +25,9 @@
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.quantization.quantize_.common import (
+    QuantizeTensorKwargs,
+)
 from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults
 
 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
@@ -38,6 +42,13 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"
 
 
+@dataclass
+class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
+    block_size: int = 16
+    is_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
 # TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
```
```diff
@@ -60,33 +71,34 @@ class NVFP4Tensor(TorchAOBaseTensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
+        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
-        mm_config (NVFP4MMConfig): Matrix multiplication configuration
         use_triton_kernel (bool): Whether to use triton kernels
     """
 
     tensor_data_names = ["qdata", "_scale_e4m3"]
-    optional_tensor_data_names = ["_per_tensor_scale"]
+    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
     tensor_attribute_names = [
         "_block_size",
         "_orig_dtype",
-        "mm_config",
         "_is_swizzled_scales",
         "use_triton_kernel",
+        "act_quant_kwargs",
     ]
 
     def __new__(
         cls,
         qdata,
         blockwise_scales,
         per_tensor_scale,
+        act_per_tensor_scale,
         block_size,
         orig_dtype,
-        mm_config=NVFP4MMConfig.DYNAMIC,
         is_swizzled_scales=False,
         use_triton_kernel=False,
+        act_quant_kwargs=None,
     ):
         # FP4 tensor size handling two paths, contiguous or not
         new_size = qdata.size()
@@ -107,11 +119,12 @@ def __new__(
         self._scale_e4m3 = blockwise_scales
         self._is_swizzled_scales = is_swizzled_scales
         self._per_tensor_scale = per_tensor_scale
+        self._act_per_tensor_scale = act_per_tensor_scale
         self.qdata = qdata
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self.mm_config = mm_config
         self.use_triton_kernel = use_triton_kernel
+        self.act_quant_kwargs = act_quant_kwargs
         return self
 
     def __repr__(self):
@@ -130,9 +143,10 @@ def to_nvfp4(
         data_hp: torch.Tensor,
         block_size: int = 16,
         per_tensor_scale: Optional[torch.Tensor] = None,
-        mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC,
+        act_per_tensor_scale: Optional[torch.Tensor] = None,
         is_swizzled_scales: bool = False,
         use_triton_kernel: bool = False,
+        act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
     ):
         """Convert high precision tensor to NVFP4 format.
 
@@ -141,9 +155,11 @@ def to_nvfp4(
             block_size: Block size for quantization (must be 16)
             per_tensor_scale: Optional pre-computed absolute maximum for calibration.
                 If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
-            mm_config: Matrix multiplication configuration
+            act_per_tensor_scale: Optional pre-computed absolute maximum for calibration for activation
+                If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
             is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
             use_triton_kernel: If True, use Triton kernel for quantization
+            act_quant_kwargs: If specified, config for quantizing the activation
 
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -169,11 +185,12 @@ def to_nvfp4(
             data_lp,
             blockwise_scales,
             per_tensor_scale,
+            act_per_tensor_scale,
             block_size,
             data_hp.dtype,
-            mm_config,
             is_swizzled_scales,
             use_triton_kernel,
+            act_quant_kwargs,
         )
 
     # Do not force the NVFP4Tensor type on the returned tensor
@@ -244,6 +261,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         per_tensor_scale_equal = (
             self._per_tensor_scale is None and src._per_tensor_scale is None
         ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+        act_per_tensor_scale_equal = (
+            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
+        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
 
         return (
             isinstance(self, NVFP4Tensor)
@@ -253,7 +273,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             and self._is_swizzled_scales == src._is_swizzled_scales
             and self._scale_e4m3.shape == src._scale_e4m3.shape
             and per_tensor_scale_equal
+            and act_per_tensor_scale_equal
             and self.qdata.shape == src.qdata.shape
+            and self.act_quant_kwargs == src.act_quant_kwargs
         )
 
 
@@ -290,12 +312,13 @@ def nvfp4_to_copy(func, types, args, kwargs):
         res = NVFP4Tensor(
             tensor._scale_e4m3,
             tensor._per_tensor_scale,
+            tensor._act_per_tensor_scale,
             tensor._data,
             tensor._block_size,
             dtype,
-            tensor.mm_config,
             tensor._is_swizzled_scales,
             tensor.use_triton_kernel,
+            tensor.act_quant_kwargs,
         )
         return res
 
@@ -491,11 +514,12 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_data,
         sliced_scale,
         x._per_tensor_scale,
+        x._act_per_tensor_scale,
         x._block_size,
         x._orig_dtype,
-        x.mm_config,
         x._is_swizzled_scales,
         x.use_triton_kernel,
+        x.act_quant_kwargs,
     )
 
     return return_and_correct_aliasing(func, args, kwargs, result)
@@ -509,11 +533,12 @@ def nvfp4_t(func, types, args, kwargs):
         old.qdata.t(),
         old._scale_e4m3,
         old._per_tensor_scale,
+        old._act_per_tensor_scale,
         old._block_size,
         old._orig_dtype,
-        old.mm_config,
         old._is_swizzled_scales,
         old.use_triton_kernel,
+        old.act_quant_kwargs,
     )
     return new
 
@@ -528,11 +553,12 @@ def nvfp4_view_op(func, types, args, kwargs):
         new_data,
         args[0]._scale_e4m3,
         args[0]._per_tensor_scale,
+        args[0]._act_per_tensor_scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0].mm_config,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
+        args[0].act_quant_kwargs,
     )
 
 
@@ -610,17 +636,19 @@ def nvfp4_linear(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
+        # weight_only quant
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
+        # dynamic quant
+        k = weight_tensor.act_quant_kwargs
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
-            mm_config=config,
-            is_swizzled_scales=True,
-            use_triton_kernel=weight_tensor.use_triton_kernel,
+            block_size=k.block_size,
+            per_tensor_scale=weight_tensor._act_per_tensor_scale,
+            is_swizzled_scales=k.is_swizzled_scales,
+            use_triton_kernel=k.use_triton_kernel,
         )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
 
@@ -632,9 +660,7 @@ def nvfp4_mm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -643,11 +669,13 @@ def nvfp4_mm(func, types, args, kwargs):
             return func(input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func)
 
@@ -659,9 +687,7 @@ def nvfp4_addmm(func, types, args, kwargs):
    if not isinstance(weight_tensor, NVFP4Tensor):
        raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -670,11 +696,13 @@ def nvfp4_addmm(func, types, args, kwargs):
             return torch.addmm(bias, input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias)
```