2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -109,7 +109,7 @@ transforms:
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
fuse_fp8_linear:
stage: post_load_fusion
backend: torch
backend: trtllm
fuse_nvfp4_linear:
stage: post_load_fusion
backend: trtllm
139 changes: 133 additions & 6 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -52,6 +52,101 @@ def _to_fp8(x, scale):
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_linear", mutates_args=())
def trtllm_quant_fp8_linear(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""FP8 linear op similar to torch.nn.linear using TensorRT-LLM FP8 operations.

Args:
        input: unquantized input tensor.
        weight_fp8: pre-quantized weight tensor with dtype torch.float8_e4m3fn.
        bias: (Optional) bias tensor to add to the linear output.
        input_scale: (Optional) pre-computed scalar tensor for static input quantization.
        weight_scale: scalar tensor for weight dequantization.

Returns:
        The linear output with the same dtype as the input.
"""
input_shape = input.shape
input_dtype = input.dtype

n = weight_fp8.shape[0] # out_features
k = weight_fp8.shape[1] # in_features

# Verify dimensions match
assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"

input = input.reshape(-1, k)

# Calculate padding needed to reach next multiple of 16
k_pad = (16 - k % 16) % 16 # Amount to pad K dimension
n_pad = (16 - n % 16) % 16 # Amount to pad N dimension

if k_pad != 0:
# Pad input on the last dimension (K dimension)
input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
# Pad weight on the last dimension (K dimension)
weight_fp8 = torch.nn.functional.pad(
weight_fp8, (0, k_pad), mode="constant", value=0
).contiguous()

if n_pad != 0:
# Pad weight on the first dimension (N dimension)
weight_fp8 = torch.nn.functional.pad(
weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
).contiguous()

# Use TensorRT-LLM FP8 per-tensor quantization
assert input_scale is not None
input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)

# Use TensorRT-LLM FP8 scaled matrix multiply
# Choose between CUDA core (for small M) and cuBLAS (for large M) implementations
    if input_fp8.shape[0] <= 8:  # NOTE: this kernel may work with n % 2 == 0 as well
# Use CUDA core for small M dimension (better for small batch sizes)
output = torch.ops.trtllm.cuda_scaled_mm(
input_fp8,
weight_fp8.t(),
scale_a=input_scale,
scale_b=weight_scale,
bias=None,
out_dtype=input_dtype,
)
else:
# Use cuBLAS for large M dimension
output = torch.ops.trtllm.cublas_scaled_mm(
input_fp8,
weight_fp8.t(),
scale_a=input_scale,
scale_b=weight_scale,
bias=None,
out_dtype=input_dtype,
)

# Remove padding from output if needed
if n_pad != 0:
output = output[..., :n]

if bias is not None:
output = output + bias
return output.reshape(*input_shape[:-1], n)


@trtllm_quant_fp8_linear.register_fake
def trtllm_quant_fp8_linear_fake(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)

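For reference, a minimal usage sketch of the new op (not part of the diff), mirroring the unit test further below; the shapes and per-tensor scales are illustrative and assume a CUDA device with FP8 (e4m3) support:

import torch

x = torch.rand(4, 64, device="cuda")          # M = 4 <= 8, so the cuda_scaled_mm path is taken
w = torch.rand(128, 64, device="cuda")        # (out_features, in_features), both multiples of 16
w_scale = torch.max(torch.abs(w)) / 448       # 448 = max representable value in float8_e4m3fn
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

out = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
    x,
    w_fp8,
    bias=None,
    input_scale=torch.tensor(1.0, device="cuda"),
    weight_scale=w_scale,
)
assert out.shape == (4, 128) and out.dtype == x.dtype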

@torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp8_linear(
@@ -72,27 +167,59 @@ def fp8_linear(
Returns:
        The linear output with the same dtype as the input.
"""
assert input.shape[-1] % 16 == 0
assert weight_fp8.shape[-1] % 16 == 0

input_shape = input.shape
weight_shape = weight_fp8.shape

# Original dimensions
n = weight_shape[0] # out_features
k = weight_shape[1] # in_features

# Verify dimensions match
assert input_shape[-1] == k, f"Input last dim {input_shape[-1]} must match weight last dim {k}"

# Calculate padding needed to reach next multiple of 16
k_pad = (16 - k % 16) % 16 # Amount to pad K dimension
n_pad = (16 - n % 16) % 16 # Amount to pad N dimension

if k_pad != 0:
# Pad input on the last dimension (K dimension)
input = torch.nn.functional.pad(input, (0, k_pad), mode="constant", value=0).contiguous()
# Pad weight on the last dimension (K dimension)
weight_fp8 = torch.nn.functional.pad(
weight_fp8, (0, k_pad), mode="constant", value=0
).contiguous()

if n_pad != 0:
# Pad weight on the first dimension (N dimension)
weight_fp8 = torch.nn.functional.pad(
weight_fp8, (0, 0, 0, n_pad), mode="constant", value=0
).contiguous()

    # CUDA graph compatibility
assert input_scale is not None
input_fp8 = _to_fp8(input, input_scale)

weight_fp8_t = weight_fp8.reshape(-1, weight_shape[-1]).t()
weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()

# If we have N padding, don't add bias in addmm (it won't match dimensions)
# We'll add it after removing padding
output = addmm_float8_unwrapped(
input_fp8.reshape(-1, input_shape[-1]),
input_fp8.reshape(-1, input.shape[-1]),
input_scale,
weight_fp8_t,
weight_scale,
input.dtype,
bias=bias,
bias=None if n_pad != 0 else bias,
use_fast_accum=True,
)

# Remove padding from output if needed
if n_pad != 0:
output = output[..., :n]
# Add bias after removing padding
if bias is not None:
output = output + bias

return output.reshape(*input_shape[:-1], output.shape[-1])

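Both FP8 paths above round K and N up to the next multiple of 16 before the scaled matmul and slice the padding off again afterwards; a quick sketch of that arithmetic with a hypothetical helper (not part of this PR):

def pad_to_multiple_of_16(dim: int) -> int:
    # Zero padding needed to round dim up to the next multiple of 16.
    return (16 - dim % 16) % 16

assert pad_to_multiple_of_16(30) == 2   # e.g. K = 30 is padded to 32
assert pad_to_multiple_of_16(32) == 0   # already aligned, nothing to pad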

74 changes: 40 additions & 34 deletions tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
@@ -34,21 +34,6 @@ def _fp8_ref_pattern_1(
)


def _fp8_ref_repl_1(
x: torch.Tensor,
w_fp8: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
):
return torch.ops.auto_deploy.torch_quant_fp8_linear(
x,
w_fp8,
None,
input_scale=input_scale,
weight_scale=weight_scale,
)


# with bias!=None
def _fp8_ref_pattern_2(
x: torch.Tensor,
@@ -68,22 +53,6 @@ def _fp8_ref_pattern_2(
)


def _fp8_ref_repl_2(
x: torch.Tensor,
w_fp8: torch.Tensor,
bias: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
):
return torch.ops.auto_deploy.torch_quant_fp8_linear(
x,
w_fp8,
bias,
input_scale=input_scale,
weight_scale=weight_scale,
)


# NVFP4: with bias=None
def _fp4_ref_pattern_1(
x: torch.Tensor,
@@ -158,10 +127,41 @@ def _fp4_ref_repl_2(
)


def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass) -> None:
def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass, op) -> None:
"""
Register FP8 linear patterns with robust dummy args and minimal ignores.
"""

# Define replacement functions that use the provided op
def _fp8_ref_repl_1(
x: torch.Tensor,
w_fp8: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
):
return op(
x,
w_fp8,
None,
input_scale=input_scale,
weight_scale=weight_scale,
)

def _fp8_ref_repl_2(
x: torch.Tensor,
w_fp8: torch.Tensor,
bias: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
):
return op(
x,
w_fp8,
bias,
input_scale=input_scale,
weight_scale=weight_scale,
)

# FP8 dummy tensors
x_fp8 = torch.randn(3, 16, device="meta", dtype=torch.float16)
w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16)
@@ -275,11 +275,17 @@ def _apply(
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if self.config.backend.lower() != "torch":
if self.config.backend.lower() not in ["torch", "trtllm"]:
raise ValueError(f"Unsupported FP8 backend: {self.config.backend}")

patterns = ADPatternMatcherPass()
_register_quant_fp8_linear_patterns(patterns)
op = (
torch.ops.auto_deploy.trtllm_quant_fp8_linear
if self.config.backend.lower() == "trtllm"
else torch.ops.auto_deploy.torch_quant_fp8_linear
)

_register_quant_fp8_linear_patterns(patterns, op)
cnt = patterns.apply(gm.graph)

info = TransformInfo(
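Because both custom ops share the same signature, the replacement closures above can forward to either backend unchanged; a condensed sketch of the selection performed in _apply (the helper name is hypothetical):

import torch

def _select_fp8_linear_op(backend: str):
    # "trtllm" dispatches to the new TensorRT-LLM-backed op, "torch" keeps the original one.
    if backend.lower() not in ("torch", "trtllm"):
        raise ValueError(f"Unsupported FP8 backend: {backend}")
    return (
        torch.ops.auto_deploy.trtllm_quant_fp8_linear
        if backend.lower() == "trtllm"
        else torch.ops.auto_deploy.torch_quant_fp8_linear
    )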
@@ -16,28 +16,39 @@
INT4_BLOCK_SIZE = 128


@pytest.mark.parametrize("bias", [torch.rand(32).to("cuda") * 10, None])
@pytest.mark.parametrize("M", [3, 12]) # NOTE: ensures both kernels are called
@pytest.mark.parametrize("N", [18, 28, 30, 32])
@pytest.mark.parametrize("K", [16, 32])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
def test_fp8_linear(bias):
input = torch.rand(3, 16).to("cuda")
weight = torch.rand(32, 16).to("cuda")
bias = torch.rand(32).to("cuda") * 10
def test_fp8_linear(M, N, K, bias):
if N % 16 != 0 or K % 16 != 0:
pytest.skip("https://github.com/NVIDIA/TensorRT-LLM/issues/8811")

input = torch.rand(M, K, device="cuda")
weight = torch.rand(N, K, device="cuda")
bias = torch.rand(N).to("cuda") * 10 if bias else None

weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)

output_fp8_gemm = torch.ops.auto_deploy.torch_quant_fp8_linear(
output_fp8_trtllm = torch.ops.auto_deploy.trtllm_quant_fp8_linear(
input,
weight_fp8,
bias=bias,
input_scale=torch.tensor(1.0).to("cuda"),
weight_scale=weight_scale,
)
output_fp32_gemm = torch.ops.aten.linear.default(input, weight, bias=bias)

assert output_fp8_gemm.shape == output_fp32_gemm.shape
output_fp8_torch = torch.ops.auto_deploy.torch_quant_fp8_linear(
input,
weight_fp8,
bias=bias,
input_scale=torch.tensor(1.0).to("cuda"),
weight_scale=weight_scale,
)
assert output_fp8_trtllm.shape == output_fp8_torch.shape

assert torch.allclose(output_fp8_gemm, output_fp32_gemm, rtol=0.01, atol=0.15)
torch.testing.assert_close(output_fp8_trtllm, output_fp8_torch, rtol=0.01, atol=0.05)


@pytest.mark.skipif(