Commit 7e96271

replace fp8 kernels NVIDIA#8820
1 parent 7861380 commit 7e96271

File tree

4 files changed: +195 -51 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ transforms:
     enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
   fuse_fp8_linear:
     stage: post_load_fusion
-    backend: torch
+    backend: trtllm
   fuse_nvfp4_linear:
     stage: post_load_fusion
     backend: trtllm
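With the default flipped from torch to trtllm, FP8 linear fusion now lowers to the TensorRT-LLM-backed custom op by default. A minimal sanity check is sketched below; it assumes TensorRT-LLM is installed and that importing the auto_deploy quant module (the file touched next in this commit) is enough to register both custom ops.

import torch

# Importing this module runs its torch.library.custom_op registrations
# (assumption: no extra plugin initialization is needed for this check).
import tensorrt_llm._torch.auto_deploy.custom_ops.quant  # noqa: F401

# Both backends remain registered; only the default selection changed.
assert hasattr(torch.ops.auto_deploy, "torch_quant_fp8_linear")
assert hasattr(torch.ops.auto_deploy, "trtllm_quant_fp8_linear")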

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 133 additions & 6 deletions
@@ -52,6 +52,101 @@ def _to_fp8(x, scale):
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


+@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_linear", mutates_args=())
+def trtllm_quant_fp8_linear(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """FP8 linear op similar to torch.nn.linear using TensorRT-LLM FP8 operations.
+
+    Args:
+        input: unquantized input tensor
+        weight_fp8: pre-quantized weight tensor, with dtype torch.float8_e4m3fn
+        input_scale: (Optional) pre-computed scalar tensor for static quantization.
+        weight_scale: scalar tensor for weight dequantization.
+
+    Returns:
+        The linear output with the same dtype as the input.
+    """
+    input_shape = input.shape
+    input_dtype = input.dtype
+
+    n = weight_fp8.shape[0]  # out_features
+    k = weight_fp8.shape[1]  # in_features
+
+    # Verify dimensions match
+    assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"
+
+    input = input.reshape(-1, k)
+
+    # Calculate padding needed to reach next multiple of 16
+    k_pad = (16 - k % 16) % 16  # Amount to pad K dimension
+    n_pad = (16 - n % 16) % 16  # Amount to pad N dimension
+
+    if k_pad != 0:
+        # Pad input on the last dimension (K dimension)
+        input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
+        # Pad weight on the last dimension (K dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, k_pad), mode="constant", value=0
+        ).contiguous()
+
+    if n_pad != 0:
+        # Pad weight on the first dimension (N dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
+        ).contiguous()
+
+    # Use TensorRT-LLM FP8 per-tensor quantization
+    assert input_scale is not None
+    input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)
+
+    # Use TensorRT-LLM FP8 scaled matrix multiply
+    # Choose between CUDA core (for small M) and cuBLAS (for large M) implementations
+    if input_fp8.shape[0] <= 8:  # NOTE: does this kernel work with n % 2 == 0 as well?
+        # Use CUDA core for small M dimension (better for small batch sizes)
+        output = torch.ops.trtllm.cuda_scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_scale,
+            scale_b=weight_scale,
+            bias=None,
+            out_dtype=input_dtype,
+        )
+    else:
+        # Use cuBLAS for large M dimension
+        output = torch.ops.trtllm.cublas_scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_scale,
+            scale_b=weight_scale,
+            bias=None,
+            out_dtype=input_dtype,
+        )
+
+    # Remove padding from output if needed
+    if n_pad != 0:
+        output = output[..., :n]
+
+    if bias is not None:
+        output = output + bias
+    return output.reshape(*input_shape[:-1], n)
+
+
+@trtllm_quant_fp8_linear.register_fake
+def trtllm_quant_fp8_linear_fake(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)
+
+
 @torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
 @torch.compile(dynamic=True)
 def fp8_linear(
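Both backends share the same padding arithmetic, so it is worth spelling out: (16 - k % 16) % 16 is the distance from k up to the next multiple of 16, and it collapses to 0 when k is already aligned. A standalone check of that formula (plain Python, no GPU required; the helper name is illustrative only):

def pad_to_multiple_of_16(dim: int) -> int:
    """Return how many zero rows/columns must be appended so that dim % 16 == 0."""
    return (16 - dim % 16) % 16

# Aligned sizes need no padding; everything else is rounded up to the next multiple of 16.
assert pad_to_multiple_of_16(16) == 0
assert pad_to_multiple_of_16(32) == 0
assert pad_to_multiple_of_16(18) == 14  # 18 -> padded to 32
assert pad_to_multiple_of_16(30) == 2   # 30 -> padded to 32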
@@ -72,27 +167,59 @@ def fp8_linear(
     Returns:
         The linear output with the original dtype as the input.
     """
-    assert input.shape[-1] % 16 == 0
-    assert weight_fp8.shape[-1] % 16 == 0
-
     input_shape = input.shape
     weight_shape = weight_fp8.shape

+    # Original dimensions
+    n = weight_shape[0]  # out_features
+    k = weight_shape[1]  # in_features
+
+    # Verify dimensions match
+    assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"
+
+    # Calculate padding needed to reach next multiple of 16
+    k_pad = (16 - k % 16) % 16  # Amount to pad K dimension
+    n_pad = (16 - n % 16) % 16  # Amount to pad N dimension
+
+    if k_pad != 0:
+        # Pad input on the last dimension (K dimension)
+        input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
+        # Pad weight on the last dimension (K dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, k_pad), mode="constant", value=0
+        ).contiguous()
+
+    if n_pad != 0:
+        # Pad weight on the first dimension (N dimension)
+        weight_fp8 = torch.nn.functional.pad(
+            weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
+        ).contiguous()
+
     # Cuda graph compatibility
     assert input_scale is not None
     input_fp8 = _to_fp8(input, input_scale)

-    weight_fp8_t = weight_fp8.reshape(-1, weight_shape[-1]).t()
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+
+    # If we have N padding, don't add bias in addmm (it won't match dimensions)
+    # We'll add it after removing padding
     output = addmm_float8_unwrapped(
-        input_fp8.reshape(-1, input_shape[-1]),
+        input_fp8.reshape(-1, input.shape[-1]),
         input_scale,
         weight_fp8_t,
         weight_scale,
         input.dtype,
-        bias=bias,
+        bias=None if n_pad != 0 else bias,
         use_fast_accum=True,
     )

+    # Remove padding from output if needed
+    if n_pad != 0:
+        output = output[..., :n]
+        # Add bias after removing padding
+        if bias is not None:
+            output = output + bias
+
     return output.reshape(*input_shape[:-1], output.shape[-1])


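For reference, a usage sketch of the new op that mirrors the updated unit test rather than documenting an official API: it assumes a CUDA GPU with FP8 support and a TensorRT-LLM build that ships the trtllm.cuda_scaled_mm / trtllm.cublas_scaled_mm kernels. M is chosen above 8 so the cuBLAS path is taken; M <= 8 would exercise the CUDA-core path instead.

import torch
from tensorrt_llm._torch.auto_deploy.custom_ops import quant  # noqa: F401  # registers the custom op

M, N, K = 12, 32, 16
x = torch.rand(M, K, device="cuda")
w = torch.rand(N, K, device="cuda")

# Per-tensor scales: weight scale from the max-abs value (448 is the e4m3 max), input scale fixed to 1.0.
w_scale = (w.abs().max() / 448).to("cuda")
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

out = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
    x,
    w_fp8,
    bias=None,
    input_scale=torch.tensor(1.0, device="cuda"),
    weight_scale=w_scale,
)
assert out.shape == (M, N)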
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py

Lines changed: 40 additions & 34 deletions
@@ -34,21 +34,6 @@ def _fp8_ref_pattern_1(
     )


-def _fp8_ref_repl_1(
-    x: torch.Tensor,
-    w_fp8: torch.Tensor,
-    input_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-):
-    return torch.ops.auto_deploy.torch_quant_fp8_linear(
-        x,
-        w_fp8,
-        None,
-        input_scale=input_scale,
-        weight_scale=weight_scale,
-    )
-
-
 # with bias!=None
 def _fp8_ref_pattern_2(
     x: torch.Tensor,
@@ -68,22 +53,6 @@ def _fp8_ref_pattern_2(
     )


-def _fp8_ref_repl_2(
-    x: torch.Tensor,
-    w_fp8: torch.Tensor,
-    bias: torch.Tensor,
-    input_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-):
-    return torch.ops.auto_deploy.torch_quant_fp8_linear(
-        x,
-        w_fp8,
-        bias,
-        input_scale=input_scale,
-        weight_scale=weight_scale,
-    )
-
-
 # NVFP4: with bias=None
 def _fp4_ref_pattern_1(
     x: torch.Tensor,
@@ -158,10 +127,41 @@ def _fp4_ref_repl_2(
     )


-def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass) -> None:
+def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass, op) -> None:
     """
     Register FP8 linear patterns with robust dummy args and minimal ignores.
     """
+
+    # Define replacement functions that use the provided op
+    def _fp8_ref_repl_1(
+        x: torch.Tensor,
+        w_fp8: torch.Tensor,
+        input_scale: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ):
+        return op(
+            x,
+            w_fp8,
+            None,
+            input_scale=input_scale,
+            weight_scale=weight_scale,
+        )
+
+    def _fp8_ref_repl_2(
+        x: torch.Tensor,
+        w_fp8: torch.Tensor,
+        bias: torch.Tensor,
+        input_scale: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ):
+        return op(
+            x,
+            w_fp8,
+            bias,
+            input_scale=input_scale,
+            weight_scale=weight_scale,
+        )
+
     # FP8 dummy tensors
     x_fp8 = torch.randn(3, 16, device="meta", dtype=torch.float16)
     w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16)
@@ -275,11 +275,17 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        if self.config.backend.lower() != "torch":
+        if self.config.backend.lower() not in ["torch", "trtllm"]:
             raise ValueError(f"Unsupported FP8 backend: {self.config.backend}")

         patterns = ADPatternMatcherPass()
-        _register_quant_fp8_linear_patterns(patterns)
+        op = (
+            torch.ops.auto_deploy.trtllm_quant_fp8_linear
+            if self.config.backend.lower() == "trtllm"
+            else torch.ops.auto_deploy.torch_quant_fp8_linear
+        )
+
+        _register_quant_fp8_linear_patterns(patterns, op)
         cnt = patterns.apply(gm.graph)

         info = TransformInfo(
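A note on the design: rather than keeping one module-level replacement function per backend, the replacement callables are now defined inside _register_quant_fp8_linear_patterns and close over the op chosen in _apply, so a single registration helper serves both torch_quant_fp8_linear and trtllm_quant_fp8_linear. The closure idea in isolation (hypothetical make_fp8_repl name, detached from the pattern matcher):

from typing import Callable, Optional

import torch


def make_fp8_repl(op: Callable) -> Callable:
    # The returned function captures `op`, mirroring how _fp8_ref_repl_1/_2
    # capture the backend op selected from config.backend in _apply().
    def _repl(x, w_fp8, input_scale, weight_scale, bias: Optional[torch.Tensor] = None):
        return op(x, w_fp8, bias, input_scale=input_scale, weight_scale=weight_scale)

    return _repl


# e.g. repl = make_fp8_repl(torch.ops.auto_deploy.trtllm_quant_fp8_linear)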

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py

Lines changed: 21 additions & 10 deletions
@@ -16,28 +16,39 @@
 INT4_BLOCK_SIZE = 128


-@pytest.mark.parametrize("bias", [torch.rand(32).to("cuda") * 10, None])
+@pytest.mark.parametrize("M", [3, 12])  # NOTE: ensures both kernels are called
+@pytest.mark.parametrize("N", [18, 28, 30, 32])
+@pytest.mark.parametrize("K", [16, 32])
+@pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
-def test_fp8_linear(bias):
-    input = torch.rand(3, 16).to("cuda")
-    weight = torch.rand(32, 16).to("cuda")
-    bias = torch.rand(32).to("cuda") * 10
+def test_fp8_linear(M, N, K, bias):
+    if N % 16 != 0 or K % 16 != 0:
+        pytest.skip("https://github.com/NVIDIA/TensorRT-LLM/issues/8811")
+
+    input = torch.rand(M, K, device="cuda")
+    weight = torch.rand(N, K, device="cuda")
+    bias = torch.rand(N).to("cuda") * 10 if bias else None

     weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
     weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)

-    output_fp8_gemm = torch.ops.auto_deploy.torch_quant_fp8_linear(
+    output_fp8_trtllm = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
         input,
         weight_fp8,
         bias=bias,
         input_scale=torch.tensor(1.0).to("cuda"),
         weight_scale=weight_scale,
     )
-    output_fp32_gemm = torch.ops.aten.linear.default(input, weight, bias=bias)
-
-    assert output_fp8_gemm.shape == output_fp32_gemm.shape
+    output_fp8_torch = torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input,
+        weight_fp8,
+        bias=bias,
+        input_scale=torch.tensor(1.0).to("cuda"),
+        weight_scale=weight_scale,
+    )
+    assert output_fp8_trtllm.shape == output_fp8_torch.shape

-    assert torch.allclose(output_fp8_gemm, output_fp32_gemm, rtol=0.01, atol=0.15)
+    torch.testing.assert_close(output_fp8_trtllm, output_fp8_torch, rtol=0.01, atol=0.05)


 @pytest.mark.skipif(
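To run just the updated test locally (assumes a CUDA GPU with FP8 support; the -k filter simply names the test above):

import pytest

# Equivalent to invoking pytest from the command line with a -k filter on test_fp8_linear.
pytest.main([
    "tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py",
    "-k", "test_fp8_linear",
    "-v",
])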
