
Commit fee314b

mxtensor: refactor activation quant to use direct logic (#2806)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 1a20585 commit fee314b
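
In short: instead of wrapping the MX weight in `to_linear_activation_quantized` with a separate `_input_activation_quant_func_mxfp` callback, the activation-quantization recipe is now stored directly on the weight `MXTensor` as a `QuantizeTensorToMXKwargs` dataclass, and the `mm`/`addmm` dispatch quantizes a high-precision activation on the fly. A conceptual sketch of the new flow (the helper name below is hypothetical, not part of torchao; the logic mirrors the `mx_ops.py` hunk further down):

```python
# Conceptual sketch only; `_quantize_activation_like_weight` is a made-up name.
from torchao.prototype.mx_formats.mx_tensor import MXTensor


def _quantize_activation_like_weight(a, weight_mx):
    if isinstance(a, MXTensor):
        return a  # caller already quantized the activation
    k = weight_mx.act_quant_kwargs  # recipe attached when the weight was quantized
    assert k is not None, "weight-only quant not yet supported"
    return MXTensor.to_mx(
        a,
        k.elem_dtype,
        k.block_size,
        k.scaling_mode,
        k.use_fp4_custom_triton_dequant_kernel,
        k.gemm_kernel_choice,
        k.pack_fp6,
    )
```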

5 files changed: 56 additions, 47 deletions

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -390,6 +390,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         use_fp4_custom_triton_dequant_kernel,
         MXGemmKernelChoice.EMULATED,
         pack_fp6,
+        None,
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
     assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
```

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 9 additions & 38 deletions
```diff
@@ -6,7 +6,6 @@

 import types
 from dataclasses import dataclass
-from typing import Optional

 import torch

@@ -18,13 +17,12 @@
     _validate_elem_dtype,
     _validate_gemm_kernel_choice,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs
 from torchao.prototype.mx_formats.nvfp4_tensor import (
     NVFP4MMConfig,
     NVFP4Tensor,
     QuantizeTensorToNVFP4Kwargs,
 )
-from torchao.quantization.quant_api import to_linear_activation_quantized
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -93,26 +91,6 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}"


-def _input_activation_quant_func_mxfp(
-    x: torch.Tensor,
-    activation_dtype: torch.dtype,
-    block_size: int,
-    scale: Optional[torch.Tensor] = None,
-):
-    """ """
-
-    # TODO scale for static quant
-
-    activation = MXTensor.to_mx(
-        x,
-        activation_dtype,
-        block_size=block_size,
-        gemm_kernel_choice=None,  # Get from weight
-        pack_fp6=False,  # TODO
-    )
-    return activation
-
-
 @register_quantize_module_handler(MXFPInferenceConfig)
 def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
@@ -121,32 +99,26 @@ def _mx_inference_linear_transform(
     # TODO handle AMD
     assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"

-    activation_dtype = config.activation_dtype
-    weight_dtype = config.weight_dtype
     weight = module.weight

     assert weight.dtype == torch.bfloat16, (
         f"Only supporting bf16 out dtype for now, got {weight.dtype}"
     )
+    act_quant_kwargs = QuantizeTensorToMXKwargs(
+        elem_dtype=config.activation_dtype,
+        block_size=config.block_size,
+        gemm_kernel_choice=config.gemm_kernel_choice,
+        pack_fp6=False,
+    )

     # Convert weight to MX Tensor
     quantized_weight = MXTensor.to_mx(
         weight,
-        weight_dtype,
+        config.weight_dtype,
         block_size=config.block_size,
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,  # TODO
-    )
-
-    input_quant_func = _input_activation_quant_func_mxfp
-    input_quant_kwargs = {
-        "block_size": config.block_size,
-        "activation_dtype": activation_dtype,
-        "scale": None,
-    }
-
-    quantized_weight = to_linear_activation_quantized(
-        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+        act_quant_kwargs=act_quant_kwargs,
     )

     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
@@ -226,7 +198,6 @@ def _nvfp4_inference_linear_transform(
         NVFP4Tensor,
         NVFP4MMConfig,
         MXGemmKernelChoice,
-        _input_activation_quant_func_mxfp,
     ]
 )

```
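
With the wrapper gone, the module-level flow is just `quantize_` plus the config; the recipe then rides along on `module.weight`. A rough usage sketch, assuming the import paths above, usable defaults on `MXFPInferenceConfig`, and hardware that passes the sm100 assert (model and shapes are illustrative):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig

# Illustrative model; bf16 weights are required by the assert in
# _mx_inference_linear_transform.
model = torch.nn.Sequential(torch.nn.Linear(512, 256)).to(torch.bfloat16).cuda()

# Each Linear weight becomes an MXTensor carrying QuantizeTensorToMXKwargs.
quantize_(model, MXFPInferenceConfig())

x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda")
y = model(x)  # the activation is quantized at dispatch time from the stored recipe
```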
torchao/prototype/mx_formats/mx_linear.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -68,6 +68,7 @@ def _to_mxfp8_dim1_kernel_wrapper(
         False,
         gemm_kernel_choice,
         False,
+        None,
     )
     mx_tensor = DTensor.from_local(
         inner,
@@ -87,6 +88,7 @@ def _to_mxfp8_dim1_kernel_wrapper(
         False,
         gemm_kernel_choice,
         False,
+        None,
     )
     return mx_tensor

```

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -80,12 +80,26 @@ def _get_gemm_choice(


 def _addmm_mx_dispatch(
-    a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
+    a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
 ) -> torch.Tensor:
     """
     Core implementation shared between mx_mm and mx_addmm.
     The only difference is whether bias is None or not.
     """
+
+    if not isinstance(a, MXTensor):
+        assert b.act_quant_kwargs is not None, "weight-only quant not yet supported"
+        k = b.act_quant_kwargs
+        a = MXTensor.to_mx(
+            a,
+            k.elem_dtype,
+            k.block_size,
+            k.scaling_mode,
+            k.use_fp4_custom_triton_dequant_kernel,
+            k.gemm_kernel_choice,
+            k.pack_fp6,
+        )
+
     gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)

     if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
@@ -148,18 +162,14 @@ def _addmm_mx_dispatch(
 def mx_mm(func, types, args, kwargs):
     a = args[0]
     b = args[1]
-    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
+    assert isinstance(b, MXTensor)

     return _addmm_mx_dispatch(a, b, func)


 @implements([aten.addmm.default])
 def mx_addmm(func, types, args, kwargs):
-    assert (
-        isinstance(args[0], torch.Tensor)
-        and isinstance(args[1], MXTensor)
-        and isinstance(args[2], MXTensor)
-    )
+    assert isinstance(args[0], torch.Tensor) and isinstance(args[2], MXTensor)
     bias = args[0]
     a = args[1]
     b = args[2]
@@ -179,6 +189,7 @@ def mx_t(func, types, args, kwargs):
         old._use_fp4_custom_triton_dequant_kernel,
         old._gemm_kernel_choice,
         old._pack_fp6,
+        old.act_quant_kwargs,
     )
     return new

@@ -223,6 +234,7 @@ def mx_view_op(func, types, args, kwargs):
         args[0]._use_fp4_custom_triton_dequant_kernel,
         args[0]._gemm_kernel_choice,
         args[0]._pack_fp6,
+        args[0].act_quant_kwargs,
     )


@@ -284,6 +296,7 @@ def mx_slice(func, types, args, kwargs):
             x._use_fp4_custom_triton_dequant_kernel,
             x._gemm_kernel_choice,
             x._pack_fp6,
+            x.act_quant_kwargs,
         ),
     )

@@ -338,6 +351,7 @@ def autocast_to_copy(func, types, args, kwargs):
         tensor._use_fp4_custom_triton_dequant_kernel,
         tensor._gemm_kernel_choice,
         tensor._pack_fp6,
+        tensor.act_quant_kwargs,
     )
     return res

```
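
The net effect of the `mx_ops.py` changes: `a` may now be a plain high-precision tensor, and `_addmm_mx_dispatch` converts it with the recipe found on the MX weight; if the weight carries no recipe (`act_quant_kwargs is None`), the assert rejects the weight-only case. A small sketch exercising that path directly, assuming the default emulated gemm (shapes and dtypes are illustrative, and the prototype may still require a recent GPU):

```python
import torch
import torch.nn.functional as F
from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs

w_hp = torch.randn(256, 512, dtype=torch.bfloat16)
w_mx = MXTensor.to_mx(
    w_hp,
    torch.float8_e4m3fn,
    block_size=32,
    act_quant_kwargs=QuantizeTensorToMXKwargs(),  # defaults: fp8 e4m3, block_size=32
)

x = torch.randn(8, 512, dtype=torch.bfloat16)
# `x` is a plain bf16 tensor; the new dispatch branch quantizes it on the fly
# from w_mx.act_quant_kwargs before running the (emulated) MX gemm.
y = F.linear(x, w_mx)
```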
torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -17,7 +17,8 @@
 * Zeros: N/A
 """

-from typing import Union
+from dataclasses import dataclass
+from typing import Optional, Union

 import torch
 from torch.distributed._tensor import DTensor
@@ -57,6 +58,9 @@
     triton_f6_e3m2_to_scaled_bf16,
     unpack_uint4,
 )
+from torchao.quantization.quantize_.common import (
+    QuantizeTensorKwargs,
+)
 from torchao.utils import TorchAOBaseTensor

 # TODO(later): read from somewhere else?
@@ -68,6 +72,16 @@
 EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2


+@dataclass
+class QuantizeTensorToMXKwargs(QuantizeTensorKwargs):
+    elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn
+    block_size: int = 32
+    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
+    use_fp4_custom_triton_dequant_kernel: bool = False
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+    pack_fp6: bool = False
+
+
 def _to_mx_rceil(
     data_hp: torch.Tensor,
     max_abs: torch.Tensor,
@@ -458,6 +472,7 @@ class MXTensor(TorchAOBaseTensor):
         "_use_fp4_custom_triton_dequant_kernel",
         "_gemm_kernel_choice",
         "_pack_fp6",
+        "act_quant_kwargs",
     ]

     def __new__(
@@ -470,6 +485,7 @@ def __new__(
         use_fp4_custom_triton_dequant_kernel,
         gemm_kernel_choice,
         pack_fp6,
+        act_quant_kwargs,
     ):
         new_size = qdata.size()
         if elem_dtype == torch.float4_e2m1fn_x2:
@@ -540,11 +556,12 @@ def __new__(
         )
         self._gemm_kernel_choice = gemm_kernel_choice
         self._pack_fp6 = pack_fp6
+        self.act_quant_kwargs = act_quant_kwargs
        return self

    def __repr__(self):
        # TODO better elem dtype print for fp4
-        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"  # noqa: E501
+        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -582,8 +599,10 @@ def to_mx(
        block_size: int = BLOCK_SIZE_DEFAULT,
        scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
        use_fp4_custom_triton_dequant_kernel: bool = False,
+        # TODO(future PR): switch default gemm to cublas
        gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
        pack_fp6: bool = False,
+        act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
    ):
        scale_e8m0_biased, data_lp = to_mx(
            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
@@ -601,6 +620,7 @@
            use_fp4_custom_triton_dequant_kernel,
            gemm_kernel_choice,
            pack_fp6,
+            act_quant_kwargs,
        )
        return DTensor.from_local(
            inner_mx_tensor,
@@ -619,6 +639,7 @@ def to_mx(
            use_fp4_custom_triton_dequant_kernel,
            gemm_kernel_choice,
            pack_fp6,
+            act_quant_kwargs,
        )

        # Do not force the MXTensor type on the returned tensor
```
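
`QuantizeTensorToMXKwargs` simply mirrors `MXTensor.to_mx`'s parameters so a quantization recipe can be recorded on the tensor and replayed later; its defaults (fp8 e4m3 elements, block size 32, FLOOR scaling, emulated gemm) match the tensor-level defaults, and a tensor created without a recipe keeps `act_quant_kwargs=None`. A minimal sketch, assuming those defaults (values illustrative):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs

k = QuantizeTensorToMXKwargs()     # all defaults, per the dataclass definition above
print(k.elem_dtype, k.block_size)  # torch.float8_e4m3fn 32

t = MXTensor.to_mx(
    torch.randn(32, 32, dtype=torch.bfloat16),
    torch.float8_e4m3fn,
    act_quant_kwargs=k,
)
print(t.act_quant_kwargs is k)  # True; the recipe also appears in repr(t)
```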
