@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import sys
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Optional
 
@@ -24,6 +25,9 @@
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.quantization.quantize_.common import (
+    QuantizeTensorKwargs,
+)
 from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults
 
 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
@@ -38,6 +42,13 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"
 
 
+@dataclass
+class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
+    block_size: int = 16
+    is_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
 # TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
@@ -60,33 +71,34 @@ class NVFP4Tensor(TorchAOBaseTensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
+        _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for the activation
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
-        mm_config (NVFP4MMConfig): Matrix multiplication configuration
         use_triton_kernel (bool): Whether to use triton kernels
     """
 
     tensor_data_names = ["qdata", "_scale_e4m3"]
-    optional_tensor_data_names = ["_per_tensor_scale"]
+    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
     tensor_attribute_names = [
         "_block_size",
         "_orig_dtype",
-        "mm_config",
         "_is_swizzled_scales",
         "use_triton_kernel",
+        "act_quant_kwargs",
     ]
 
     def __new__(
         cls,
         qdata,
         blockwise_scales,
         per_tensor_scale,
+        act_per_tensor_scale,
         block_size,
         orig_dtype,
-        mm_config=NVFP4MMConfig.DYNAMIC,
         is_swizzled_scales=False,
         use_triton_kernel=False,
+        act_quant_kwargs=None,
     ):
         # FP4 tensor size handling two paths, contiguous or not
         new_size = qdata.size()
@@ -107,11 +119,12 @@ def __new__(
         self._scale_e4m3 = blockwise_scales
         self._is_swizzled_scales = is_swizzled_scales
         self._per_tensor_scale = per_tensor_scale
+        self._act_per_tensor_scale = act_per_tensor_scale
         self.qdata = qdata
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self.mm_config = mm_config
         self.use_triton_kernel = use_triton_kernel
+        self.act_quant_kwargs = act_quant_kwargs
         return self
 
     def __repr__(self):
@@ -130,9 +143,10 @@ def to_nvfp4(
         data_hp: torch.Tensor,
         block_size: int = 16,
         per_tensor_scale: Optional[torch.Tensor] = None,
-        mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC,
+        act_per_tensor_scale: Optional[torch.Tensor] = None,
         is_swizzled_scales: bool = False,
         use_triton_kernel: bool = False,
+        act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
     ):
         """Convert high precision tensor to NVFP4 format.
 
@@ -141,9 +155,11 @@ def to_nvfp4(
             block_size: Block size for quantization (must be 16)
             per_tensor_scale: Optional pre-computed absolute maximum for calibration.
                 If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
-            mm_config: Matrix multiplication configuration
+            act_per_tensor_scale: Optional pre-computed absolute maximum for activation calibration.
+                If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
             is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
             use_triton_kernel: If True, use Triton kernel for quantization
+            act_quant_kwargs: If specified, config for quantizing the activation
 
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -169,11 +185,12 @@ def to_nvfp4(
             data_lp,
             blockwise_scales,
             per_tensor_scale,
+            act_per_tensor_scale,
             block_size,
             data_hp.dtype,
-            mm_config,
             is_swizzled_scales,
             use_triton_kernel,
+            act_quant_kwargs,
         )
 
     # Do not force the NVFP4Tensor type on the returned tensor
@@ -244,6 +261,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         per_tensor_scale_equal = (
             self._per_tensor_scale is None and src._per_tensor_scale is None
         ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+        act_per_tensor_scale_equal = (
+            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
+        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)
 
         return (
             isinstance(self, NVFP4Tensor)
@@ -253,7 +273,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             and self._is_swizzled_scales == src._is_swizzled_scales
             and self._scale_e4m3.shape == src._scale_e4m3.shape
             and per_tensor_scale_equal
+            and act_per_tensor_scale_equal
             and self.qdata.shape == src.qdata.shape
+            and self.act_quant_kwargs == src.act_quant_kwargs
         )
 
 
@@ -290,12 +312,13 @@ def nvfp4_to_copy(func, types, args, kwargs):
         res = NVFP4Tensor(
             tensor._scale_e4m3,
             tensor._per_tensor_scale,
+            tensor._act_per_tensor_scale,
             tensor._data,
             tensor._block_size,
             dtype,
-            tensor.mm_config,
             tensor._is_swizzled_scales,
             tensor.use_triton_kernel,
+            tensor.act_quant_kwargs,
         )
         return res
 
@@ -491,11 +514,12 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_data,
         sliced_scale,
         x._per_tensor_scale,
+        x._act_per_tensor_scale,
         x._block_size,
         x._orig_dtype,
-        x.mm_config,
         x._is_swizzled_scales,
         x.use_triton_kernel,
+        x.act_quant_kwargs,
     )
 
     return return_and_correct_aliasing(func, args, kwargs, result)
@@ -509,11 +533,12 @@ def nvfp4_t(func, types, args, kwargs):
         old.qdata.t(),
         old._scale_e4m3,
         old._per_tensor_scale,
+        old._act_per_tensor_scale,
         old._block_size,
         old._orig_dtype,
-        old.mm_config,
         old._is_swizzled_scales,
         old.use_triton_kernel,
+        old.act_quant_kwargs,
     )
     return new
 
@@ -528,11 +553,12 @@ def nvfp4_view_op(func, types, args, kwargs):
         new_data,
         args[0]._scale_e4m3,
         args[0]._per_tensor_scale,
+        args[0]._act_per_tensor_scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0].mm_config,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
+        args[0].act_quant_kwargs,
     )
 
 
@@ -610,17 +636,19 @@ def nvfp4_linear(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
+        # weight-only quant
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
+        # dynamic quant
+        k = weight_tensor.act_quant_kwargs
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
-            mm_config=config,
-            is_swizzled_scales=True,
-            use_triton_kernel=weight_tensor.use_triton_kernel,
+            block_size=k.block_size,
+            per_tensor_scale=weight_tensor._act_per_tensor_scale,
+            is_swizzled_scales=k.is_swizzled_scales,
+            use_triton_kernel=k.use_triton_kernel,
         )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
@@ -632,9 +660,7 @@ def nvfp4_mm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -643,11 +669,13 @@ def nvfp4_mm(func, types, args, kwargs):
             return func(input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func)
@@ -659,9 +687,7 @@ def nvfp4_addmm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")
 
-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -670,11 +696,13 @@ def nvfp4_addmm(func, types, args, kwargs):
             return torch.addmm(bias, input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
            )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias)
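
The hunks above replace the weight's `mm_config` enum with an optional `act_quant_kwargs` stored on the weight tensor. A minimal usage sketch follows, assuming this file lives at `torchao.prototype.mx_formats.nvfp4_tensor` (the path is not shown in the diff), that `QuantizeTensorToNVFP4Kwargs` can be instantiated directly with keywords, and that a CUDA device is available:

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
)

weight = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# Default: no activation config, the analogue of the old WEIGHT_ONLY mode.
w_weight_only = NVFP4Tensor.to_nvfp4(weight)
assert w_weight_only.act_quant_kwargs is None

# Dynamic mode: store the activation-quantization settings on the weight.
act_kwargs = QuantizeTensorToNVFP4Kwargs(
    block_size=16,
    is_swizzled_scales=True,
    use_triton_kernel=False,
)
# Per the docstring, act_per_tensor_scale is a pre-computed absolute
# maximum from calibration; a hypothetical calibration batch is assumed.
calib_act = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
w_dynamic = NVFP4Tensor.to_nvfp4(
    weight,
    act_per_tensor_scale=calib_act.abs().amax().to(torch.float32),
    act_quant_kwargs=act_kwargs,
)
```

Keeping the activation settings on the weight lets one tensor subclass serve both flows; the op handlers branch on `act_quant_kwargs is None` instead of comparing an enum.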
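Continuing under the same module-path assumption, a sketch of how the rewritten `nvfp4_linear` handler behaves end to end; since the handler is registered through the `implements` table, a plain `F.linear` call is enough to reach it:

```python
import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
)

x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
w_hp = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# act_quant_kwargs is None: nvfp4_linear dequantizes the weight to its
# original dtype and falls back to a high-precision F.linear
# (the old WEIGHT_ONLY behavior).
w_wo = NVFP4Tensor.to_nvfp4(w_hp)
y_wo = F.linear(x, w_wo)

# act_quant_kwargs set: nvfp4_linear quantizes x on the fly using the
# stored kwargs and _act_per_tensor_scale, then runs the NVFP4 matmul
# through _addmm_nvfp4_dispatch (the old DYNAMIC behavior).
w_dyn = NVFP4Tensor.to_nvfp4(
    w_hp,
    act_quant_kwargs=QuantizeTensorToNVFP4Kwargs(is_swizzled_scales=True),
)
y_dyn = F.linear(x, w_dyn)
```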