@@ -52,6 +52,101 @@ def _to_fp8(x, scale):
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


+@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_linear", mutates_args=())
+def trtllm_quant_fp8_linear(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
63+ """FP8 linear op similar to torch.nn.linear using TensorRT-LLM FP8 operations.
64+
65+ Args:
66+ input: unquantized input tensor
67+ weight_fp8: pre-quantized weight tensor, with dtype torch.float8_e4m3fn
68+ input_scale: (Optional) pre-computed scalar tensor for static quantization.
69+ weight_scale: scalar tensor for weight dequantization.
70+
71+ Returns:
72+ The linear output with the original dtype as the input.
73+ """
+    input_shape = input.shape
+    input_dtype = input.dtype
+
+    n = weight_fp8.shape[0]  # out_features
+    k = weight_fp8.shape[1]  # in_features
+
+    # Verify dimensions match
+    assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"
+
+    input = input.reshape(-1, k)
+
+    # Calculate padding needed to reach next multiple of 16
+    k_pad = (16 - k % 16) % 16  # Amount to pad K dimension
+    n_pad = (16 - n % 16) % 16  # Amount to pad N dimension
+
+    if k_pad != 0:
+        # Pad input on the last dimension (K dimension)
+        input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
+        # Pad weight on the last dimension (K dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, k_pad), mode="constant", value=0
+        ).contiguous()
+
+    if n_pad != 0:
+        # Pad weight on the first dimension (N dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
+        ).contiguous()
+
+    # Use TensorRT-LLM FP8 per-tensor quantization
+    assert input_scale is not None
+    input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)
+
+    # Use TensorRT-LLM FP8 scaled matrix multiply.
+    # Choose between CUDA core (for small M) and cuBLAS (for large M) implementations.
+    if input_fp8.shape[0] <= 8:  # NOTE: does this kernel also work when n % 2 == 0?
+        # Use CUDA core for small M dimension (better for small batch sizes)
+        output = torch.ops.trtllm.cuda_scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_scale,
+            scale_b=weight_scale,
+            bias=None,
+            out_dtype=input_dtype,
+        )
+    else:
+        # Use cuBLAS for large M dimension
+        output = torch.ops.trtllm.cublas_scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_scale,
+            scale_b=weight_scale,
+            bias=None,
+            out_dtype=input_dtype,
+        )
+
+    # Remove padding from output if needed
+    if n_pad != 0:
+        output = output[..., :n]
+
+    if bias is not None:
+        output = output + bias
+    return output.reshape(*input_shape[:-1], n)
+
+
+@trtllm_quant_fp8_linear.register_fake
+def trtllm_quant_fp8_linear_fake(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)
+
+
 @torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
 @torch.compile(dynamic=True)
 def fp8_linear(
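For context (not part of the diff), a minimal sketch of how the new op could be invoked once registered: the call goes through torch.ops.auto_deploy.trtllm_quant_fp8_linear and requires the TensorRT-LLM torch bindings on a CUDA device so that torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor and the torch.ops.trtllm scaled-mm kernels resolve. The shapes and the amax/448 per-tensor scales below are illustrative assumptions, not values from the PR.

import torch

# Illustrative shapes; K = 1000 is intentionally not a multiple of 16 to exercise the padding path.
M, K, N = 4, 1000, 512
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Hypothetical per-tensor scales (amax / 448, the float8_e4m3fn max); in a real flow these
# come from calibration and are stored with the checkpoint.
input_scale = (x.abs().amax() / 448.0).float()
weight_scale = (w.abs().amax() / 448.0).float()
weight_fp8 = (w / weight_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

out = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
    x, weight_fp8, None, input_scale, weight_scale  # bias omitted
)
print(out.shape, out.dtype)  # (4, 512), dtype matches x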
@@ -72,27 +167,59 @@ def fp8_linear(
     Returns:
         The linear output with the original dtype of the input.
     """
-    assert input.shape[-1] % 16 == 0
-    assert weight_fp8.shape[-1] % 16 == 0
-
     input_shape = input.shape
     weight_shape = weight_fp8.shape

+    # Original dimensions
+    n = weight_shape[0]  # out_features
+    k = weight_shape[1]  # in_features
+
+    # Verify dimensions match
+    assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"
+
+    # Calculate padding needed to reach next multiple of 16
+    k_pad = (16 - k % 16) % 16  # Amount to pad K dimension
+    n_pad = (16 - n % 16) % 16  # Amount to pad N dimension
+
+    if k_pad != 0:
+        # Pad input on the last dimension (K dimension)
+        input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
+        # Pad weight on the last dimension (K dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, k_pad), mode="constant", value=0
+        ).contiguous()
+
+    if n_pad != 0:
+        # Pad weight on the first dimension (N dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
+        ).contiguous()
+
     # Cuda graph compatibility
     assert input_scale is not None
     input_fp8 = _to_fp8(input, input_scale)

-    weight_fp8_t = weight_fp8.reshape(-1, weight_shape[-1]).t()
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+
+    # If we have N padding, don't add bias in addmm (it won't match dimensions)
+    # We'll add it after removing padding
     output = addmm_float8_unwrapped(
-        input_fp8.reshape(-1, input_shape[-1]),
+        input_fp8.reshape(-1, input.shape[-1]),
         input_scale,
         weight_fp8_t,
         weight_scale,
         input.dtype,
-        bias=bias,
+        bias=None if n_pad != 0 else bias,
         use_fast_accum=True,
     )

+    # Remove padding from output if needed
+    if n_pad != 0:
+        output = output[..., :n]
+        # Add bias after removing padding
+        if bias is not None:
+            output = output + bias
+
     return output.reshape(*input_shape[:-1], output.shape[-1])


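The pad-to-multiple-of-16 logic added in both functions follows the same pattern; below is a small standalone sketch (plain PyTorch, with an illustrative helper name _pad_to_16 that is not from the PR) showing why zero-padding K leaves the matmul result unchanged and why the bias is added only after the padded N columns are sliced off.

import torch
import torch.nn.functional as F

def _pad_to_16(x, w):
    # x: (M, K) activations, w: (N, K) weights
    n, k = w.shape
    k_pad = (16 - k % 16) % 16  # e.g. k=1000 -> k_pad=8
    n_pad = (16 - n % 16) % 16  # e.g. n=120  -> n_pad=8
    if k_pad:
        x = F.pad(x, (0, k_pad))        # zero columns contribute nothing to the dot products
        w = F.pad(w, (0, k_pad))
    if n_pad:
        w = F.pad(w, (0, 0, 0, n_pad))  # extra rows of w become extra output columns, sliced off later
    return x, w, n, n_pad

x, w, bias = torch.randn(3, 1000), torch.randn(120, 1000), torch.randn(120)
xp, wp, n, n_pad = _pad_to_16(x, w)
out = xp @ wp.t()
out = out[..., :n] + bias               # drop padded columns first, then add the bias
assert torch.allclose(out, x @ w.t() + bias, atol=1e-4)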