@@ -17,7 +17,8 @@
 * Zeros: N/A
 """

-from typing import Union
+from dataclasses import dataclass
+from typing import Optional, Union

 import torch
 from torch.distributed._tensor import DTensor
@@ -57,6 +58,9 @@
     triton_f6_e3m2_to_scaled_bf16,
     unpack_uint4,
 )
+from torchao.quantization.quantize_.common import (
+    QuantizeTensorKwargs,
+)
 from torchao.utils import TorchAOBaseTensor

 # TODO(later): read from somewhere else?
@@ -68,6 +72,16 @@
 EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2


+@dataclass
+class QuantizeTensorToMXKwargs(QuantizeTensorKwargs):
+    elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn
+    block_size: int = 32
+    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
+    use_fp4_custom_triton_dequant_kernel: bool = False
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+    pack_fp6: bool = False
+
+
 def _to_mx_rceil(
     data_hp: torch.Tensor,
     max_abs: torch.Tensor,
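Usage note (not part of the diff): a minimal sketch of constructing the new QuantizeTensorToMXKwargs dataclass added above. The field names and defaults come from this diff; the import paths are assumptions about the surrounding module layout and may differ.

import torch

# Assumed import paths, for illustration only; the dataclass is defined in the
# file being modified here, and the enums come from the mx_formats configs.
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_tensor import (
    QuantizeTensorToMXKwargs,
    ScaleCalculationMode,
)

# Override only the fields the dataclass defines; everything else keeps the
# defaults shown in the diff (block_size=32, FLOOR scaling, emulated gemm).
act_kwargs = QuantizeTensorToMXKwargs(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    scaling_mode=ScaleCalculationMode.FLOOR,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
    pack_fp6=False,
)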
@@ -458,6 +472,7 @@ class MXTensor(TorchAOBaseTensor):
         "_use_fp4_custom_triton_dequant_kernel",
         "_gemm_kernel_choice",
         "_pack_fp6",
+        "act_quant_kwargs",
     ]

     def __new__(
@@ -470,6 +485,7 @@ def __new__(
         use_fp4_custom_triton_dequant_kernel,
         gemm_kernel_choice,
         pack_fp6,
+        act_quant_kwargs,
     ):
         new_size = qdata.size()
         if elem_dtype == torch.float4_e2m1fn_x2:
@@ -540,11 +556,12 @@ def __new__(
         )
         self._gemm_kernel_choice = gemm_kernel_choice
         self._pack_fp6 = pack_fp6
+        self.act_quant_kwargs = act_quant_kwargs
         return self

     def __repr__(self):
         # TODO better elem dtype print for fp4
-        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"  # noqa: E501
+        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -582,8 +599,10 @@ def to_mx(
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
         use_fp4_custom_triton_dequant_kernel: bool = False,
+        # TODO(future PR): switch default gemm to cublas
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
+        act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
     ):
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
@@ -601,6 +620,7 @@ def to_mx(
                 use_fp4_custom_triton_dequant_kernel,
                 gemm_kernel_choice,
                 pack_fp6,
+                act_quant_kwargs,
             )
             return DTensor.from_local(
                 inner_mx_tensor,
@@ -619,6 +639,7 @@ def to_mx(
             use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,
+            act_quant_kwargs,
         )

         # Do not force the MXTensor type on the returned tensor
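Usage note (not part of the diff): a hedged sketch of how the new act_quant_kwargs parameter flows through to_mx, assuming to_mx is the MXTensor classmethod shown in the hunks above. The import path and the bfloat16 setup are illustrative assumptions, not part of this change.

import torch

# Assumed import path, for illustration only.
from torchao.prototype.mx_formats.mx_tensor import (
    MXTensor,
    QuantizeTensorToMXKwargs,
)

x = torch.randn(128, 256, dtype=torch.bfloat16)

# Quantize to MX while recording, on the resulting tensor subclass, how
# activations should later be quantized (the new act_quant_kwargs attribute).
w_mx = MXTensor.to_mx(
    x,
    torch.float8_e4m3fn,
    block_size=32,
    act_quant_kwargs=QuantizeTensorToMXKwargs(elem_dtype=torch.float8_e4m3fn),
)
print(w_mx.act_quant_kwargs)  # shows the kwargs carried on the tensor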