diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 0cb4f8889e0..ecc8928d792 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -109,7 +109,7 @@ transforms:
     enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
   fuse_fp8_linear:
     stage: post_load_fusion
-    backend: torch
+    backend: trtllm
  fuse_nvfp4_linear:
     stage: post_load_fusion
     backend: trtllm
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
index cc4c2b6bd1f..d219abd5951 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -52,6 +52,101 @@ def _to_fp8(x, scale):
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
 
 
+@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_linear", mutates_args=())
+def trtllm_quant_fp8_linear(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """FP8 linear op, similar to torch.nn.functional.linear, using TensorRT-LLM FP8 operations.
+
+    Args:
+        input: unquantized input tensor
+        weight_fp8: pre-quantized weight tensor with dtype torch.float8_e4m3fn
+        input_scale: (Optional) pre-computed scalar tensor for static quantization.
+        weight_scale: scalar tensor for weight dequantization.
+
+    Returns:
+        The linear output with the same dtype as the input.
+    """
+    input_shape = input.shape
+    input_dtype = input.dtype
+
+    n = weight_fp8.shape[0]  # out_features
+    k = weight_fp8.shape[1]  # in_features
+
+    # Verify dimensions match
+    assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"
+
+    input = input.reshape(-1, k)
+
+    # Calculate padding needed to reach next multiple of 16
+    k_pad = (16 - k % 16) % 16  # Amount to pad K dimension
+    n_pad = (16 - n % 16) % 16  # Amount to pad N dimension
+
+    if k_pad != 0:
+        # Pad input on the last dimension (K dimension)
+        input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
+        # Pad weight on the last dimension (K dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, k_pad), mode="constant", value=0
+        ).contiguous()
+
+    if n_pad != 0:
+        # Pad weight on the first dimension (N dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
+        ).contiguous()
+
+    # Use TensorRT-LLM FP8 per-tensor quantization
+    assert input_scale is not None
+    input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)
+
+    # Use TensorRT-LLM FP8 scaled matrix multiply
+    # Choose between CUDA core (for small M) and cuBLAS (for large M) implementations
+    if input_fp8.shape[0] <= 8:  # NOTE: does this kernel also work with n % 2 == 0?
+        # Use CUDA core for small M dimension (better for small batch sizes)
+        output = torch.ops.trtllm.cuda_scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_scale,
+            scale_b=weight_scale,
+            bias=None,
+            out_dtype=input_dtype,
+        )
+    else:
+        # Use cuBLAS for large M dimension
+        output = torch.ops.trtllm.cublas_scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_scale,
+            scale_b=weight_scale,
+            bias=None,
+            out_dtype=input_dtype,
+        )
+
+    # Remove padding from output if needed
+    if n_pad != 0:
+        output = output[..., :n]
+
+    if bias is not None:
+        output = output + bias
+    return output.reshape(*input_shape[:-1], n)
+
+
+@trtllm_quant_fp8_linear.register_fake
+def trtllm_quant_fp8_linear_fake(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)
+
+
 @torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
 @torch.compile(dynamic=True)
 def fp8_linear(
@@ -72,27 +167,59 @@ def fp8_linear(
     Returns:
         The linear output with the original dtype as the input.
     """
-    assert input.shape[-1] % 16 == 0
-    assert weight_fp8.shape[-1] % 16 == 0
-
     input_shape = input.shape
     weight_shape = weight_fp8.shape
 
+    # Original dimensions
+    n = weight_shape[0]  # out_features
+    k = weight_shape[1]  # in_features
+
+    # Verify dimensions match
+    assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"
+
+    # Calculate padding needed to reach next multiple of 16
+    k_pad = (16 - k % 16) % 16  # Amount to pad K dimension
+    n_pad = (16 - n % 16) % 16  # Amount to pad N dimension
+
+    if k_pad != 0:
+        # Pad input on the last dimension (K dimension)
+        input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
+        # Pad weight on the last dimension (K dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, k_pad), mode="constant", value=0
+        ).contiguous()
+
+    if n_pad != 0:
+        # Pad weight on the first dimension (N dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
+        ).contiguous()
+
     # Cuda graph compatibility
     assert input_scale is not None
     input_fp8 = _to_fp8(input, input_scale)
-    weight_fp8_t = weight_fp8.reshape(-1, weight_shape[-1]).t()
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+
+    # If we have N padding, don't add bias in addmm (it won't match dimensions)
+    # We'll add it after removing padding
     output = addmm_float8_unwrapped(
-        input_fp8.reshape(-1, input_shape[-1]),
+        input_fp8.reshape(-1, input.shape[-1]),
         input_scale,
         weight_fp8_t,
         weight_scale,
         input.dtype,
-        bias=bias,
+        bias=None if n_pad != 0 else bias,
         use_fast_accum=True,
     )
 
+    # Remove padding from output if needed
+    if n_pad != 0:
+        output = output[..., :n]
+        # Add bias after removing padding
+        if bias is not None:
+            output = output + bias
+
     return output.reshape(*input_shape[:-1], output.shape[-1])
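Both ops above rely on the same pad-to-16 arithmetic before calling the FP8 GEMM. For readers skimming the diff, here is a minimal standalone sketch of that arithmetic; the helper name `_pad_to_16` is illustrative and not part of the patch:

```python
# Sketch of the padding rule used by trtllm_quant_fp8_linear and torch_quant_fp8_linear:
# (16 - dim % 16) % 16 is the smallest amount that rounds `dim` up to a multiple of 16,
# and it is 0 when `dim` is already aligned.
def _pad_to_16(dim: int) -> int:  # hypothetical helper, for illustration only
    return (16 - dim % 16) % 16

assert _pad_to_16(16) == 0   # already aligned -> no padding
assert _pad_to_16(18) == 14  # 18 + 14 = 32
assert _pad_to_16(30) == 2   # 30 + 2 = 32
```

The trtllm op additionally picks a kernel from the flattened M dimension: `torch.ops.trtllm.cuda_scaled_mm` when M <= 8, `torch.ops.trtllm.cublas_scaled_mm` otherwise.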
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
index 3380442061c..5fb3ba0966f 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
@@ -34,21 +34,6 @@ def _fp8_ref_pattern_1(
     )
 
 
-def _fp8_ref_repl_1(
-    x: torch.Tensor,
-    w_fp8: torch.Tensor,
-    input_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-):
-    return torch.ops.auto_deploy.torch_quant_fp8_linear(
-        x,
-        w_fp8,
-        None,
-        input_scale=input_scale,
-        weight_scale=weight_scale,
-    )
-
-
 # with bias!=None
 def _fp8_ref_pattern_2(
     x: torch.Tensor,
@@ -68,22 +53,6 @@ def _fp8_ref_pattern_2(
     )
 
 
-def _fp8_ref_repl_2(
-    x: torch.Tensor,
-    w_fp8: torch.Tensor,
-    bias: torch.Tensor,
-    input_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-):
-    return torch.ops.auto_deploy.torch_quant_fp8_linear(
-        x,
-        w_fp8,
-        bias,
-        input_scale=input_scale,
-        weight_scale=weight_scale,
-    )
-
-
 # NVFP4: with bias=None
 def _fp4_ref_pattern_1(
     x: torch.Tensor,
@@ -158,10 +127,41 @@ def _fp4_ref_repl_2(
     )
 
 
-def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass) -> None:
+def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass, op) -> None:
     """
     Register FP8 linear patterns with robust dummy args and minimal ignores.
     """
+
+    # Define replacement functions that use the provided op
+    def _fp8_ref_repl_1(
+        x: torch.Tensor,
+        w_fp8: torch.Tensor,
+        input_scale: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ):
+        return op(
+            x,
+            w_fp8,
+            None,
+            input_scale=input_scale,
+            weight_scale=weight_scale,
+        )
+
+    def _fp8_ref_repl_2(
+        x: torch.Tensor,
+        w_fp8: torch.Tensor,
+        bias: torch.Tensor,
+        input_scale: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ):
+        return op(
+            x,
+            w_fp8,
+            bias,
+            input_scale=input_scale,
+            weight_scale=weight_scale,
+        )
+
     # FP8 dummy tensors
     x_fp8 = torch.randn(3, 16, device="meta", dtype=torch.float16)
     w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16)
@@ -275,11 +275,17 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        if self.config.backend.lower() != "torch":
+        if self.config.backend.lower() not in ["torch", "trtllm"]:
             raise ValueError(f"Unsupported FP8 backend: {self.config.backend}")
 
         patterns = ADPatternMatcherPass()
-        _register_quant_fp8_linear_patterns(patterns)
+        op = (
+            torch.ops.auto_deploy.trtllm_quant_fp8_linear
+            if self.config.backend.lower() == "trtllm"
+            else torch.ops.auto_deploy.torch_quant_fp8_linear
+        )
+
+        _register_quant_fp8_linear_patterns(patterns, op)
         cnt = patterns.apply(gm.graph)
 
         info = TransformInfo(
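Both backends expose the same call signature, which is what lets the transform above swap only the op object handed to the pattern replacements. A hedged usage sketch, mirroring the scale recipe of the unit test below; it assumes an FP8-capable GPU and a TensorRT-LLM build that registers these custom ops, and keeps N and K as multiples of 16 to stay off the padding path:

```python
import torch

x = torch.rand(4, 32, device="cuda")
w = torch.rand(64, 32, device="cuda")

# Per-tensor FP8 recipe as in the unit test: 448 is the largest normal float8_e4m3fn value.
w_scale = torch.max(torch.abs(w)) / 448
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
in_scale = torch.tensor(1.0, device="cuda")

# Either backend can be called directly; fuse_fp8_linear picks one of them
# based on the `backend` key in default.yaml.
y_trtllm = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
    x, w_fp8, None, input_scale=in_scale, weight_scale=w_scale
)
y_torch = torch.ops.auto_deploy.torch_quant_fp8_linear(
    x, w_fp8, None, input_scale=in_scale, weight_scale=w_scale
)
```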
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
index d28aba51134..bfe8d75f1a4 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
@@ -16,28 +16,39 @@
 INT4_BLOCK_SIZE = 128
 
 
-@pytest.mark.parametrize("bias", [torch.rand(32).to("cuda") * 10, None])
+@pytest.mark.parametrize("M", [3, 12])  # NOTE: ensures both kernels are called
+@pytest.mark.parametrize("N", [18, 28, 30, 32])
+@pytest.mark.parametrize("K", [16, 32])
+@pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
-def test_fp8_linear(bias):
-    input = torch.rand(3, 16).to("cuda")
-    weight = torch.rand(32, 16).to("cuda")
-    bias = torch.rand(32).to("cuda") * 10
+def test_fp8_linear(M, N, K, bias):
+    if N % 16 != 0 or K % 16 != 0:
+        pytest.skip("https://github.com/NVIDIA/TensorRT-LLM/issues/8811")
+
+    input = torch.rand(M, K, device="cuda")
+    weight = torch.rand(N, K, device="cuda")
+    bias = torch.rand(N).to("cuda") * 10 if bias else None
     weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
     weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)
 
-    output_fp8_gemm = torch.ops.auto_deploy.torch_quant_fp8_linear(
+    output_fp8_trtllm = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
         input,
         weight_fp8,
         bias=bias,
         input_scale=torch.tensor(1.0).to("cuda"),
         weight_scale=weight_scale,
     )
-    output_fp32_gemm = torch.ops.aten.linear.default(input, weight, bias=bias)
-
-    assert output_fp8_gemm.shape == output_fp32_gemm.shape
+    output_fp8_torch = torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input,
+        weight_fp8,
+        bias=bias,
+        input_scale=torch.tensor(1.0).to("cuda"),
+        weight_scale=weight_scale,
+    )
+    assert output_fp8_trtllm.shape == output_fp8_torch.shape
 
-    assert torch.allclose(output_fp8_gemm, output_fp32_gemm, rtol=0.01, atol=0.15)
+    torch.testing.assert_close(output_fp8_trtllm, output_fp8_torch, rtol=0.01, atol=0.05)
 
 
 @pytest.mark.skipif(
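For intuition on the 448 scale factor used by the test, a rough CPU-side sketch of the per-tensor e4m3 quantize/dequantize round trip (it assumes a PyTorch build with `torch.float8_e4m3fn`; illustrative only, not a derivation of the chosen rtol/atol):

```python
import torch

w = torch.rand(64, 32)                    # same distribution as the test inputs
w_scale = torch.max(torch.abs(w)) / 448   # 448 = max normal value of float8_e4m3fn
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
w_rt = w_fp8.to(torch.float32) * w_scale  # dequantized round trip
# The round-trip error is bounded by half an e4m3 quantization step times w_scale.
print((w - w_rt).abs().max())
```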