diff --git a/README.md b/README.md index 747fe0278df..a71771e7adf 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ TensorRT-LLM [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/) [![python](https://img.shields.io/badge/python-3.12-green)](https://www.python.org/downloads/release/python-3123/) [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/) -[![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads) -[![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt) +[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads) +[![trt](https://img.shields.io/badge/TRT-10.13.2-green)](https://developer.nvidia.com/tensorrt) [![version](https://img.shields.io/badge/release-1.1.0rc6-green)](./tensorrt_llm/version.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt index cdae331b945..f690ab5a905 100644 --- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt +++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt @@ -1,4 +1,4 @@ -set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72) +set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9) set(NVSHMEM_URL_HASH SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 2e01d47eb5b..89f73af4379 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -553,6 +553,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface || std::is_same_v) &&!std::is_same_v; static constexpr bool use_w4afp8 = std::is_same_v && std::is_same_v; + static constexpr bool use_fp8_input = std::is_same_v; static_assert(!std::is_same_v, "Current logic requires backbone type to be >=16-bits"); static_assert(!std::is_same_v, "Current logic requires output type to be >=16-bits"); #else diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index ca613155a9d..38efe497f7c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1625,7 +1625,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, else if constexpr (std::is_same_v && std::is_same_v) { - TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); + TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ"); return quant_params.mxfp8_mxfp4.fc1.weight_block_scale ? 
&expandInputRowsKernel @@ -3689,7 +3689,7 @@ void CutlassMoeFCRunner; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>; #endif #endif #ifdef ENABLE_FP4 diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 9b9731a076e..db1ab0e3621 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -201,6 +201,21 @@ class FusedMoeRunner : public torch::CustomClassHolder } switch (mActivationDtype) { +#ifdef ENABLE_FP8 + case c10::ScalarType::Float8_e4m3fn: + { + if (isInt4Quant() and mUseW4GroupScaling) + { + mKernelRunner = std::make_unique< + kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>>(); + } + else + { + C10_THROW_ERROR_FORMATTED(Error, "FP8 activation type is not supported for non-W4A8 quantization"); + } + break; + } +#endif case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner(); break; case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break; default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight"); diff --git a/docs/source/developer-guide/perf-benchmarking.md b/docs/source/developer-guide/perf-benchmarking.md index 6c7dbc97c34..3881a9ab263 100644 --- a/docs/source/developer-guide/perf-benchmarking.md +++ b/docs/source/developer-guide/perf-benchmarking.md @@ -460,9 +460,10 @@ If you would like to force the KV cache quantization, you can specify the follow when the checkpoint precision is `null`: ```yaml -kv_cache_dtype: "fp8" +kv_cache_config: + dtype: fp8 ``` ```{tip} -The two valid values for `kv_cache_dtype` are `auto` and `fp8`. +The two valid values for `kv_cache_config.dtype` are `auto` and `fp8`. ``` diff --git a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py index c852bdb929c..84673163b82 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py +++ b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py @@ -2,7 +2,7 @@ # https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md import os import weakref -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -179,12 +179,18 @@ def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor, return recv_hidden_states, recv_scales, recv_expert_count, handle - def low_latency_combine_fp4(self, hidden_states: torch.Tensor, - global_scales: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, handle: Tuple): + def low_latency_combine_low_precision(self, precision: str, + hidden_states: torch.Tensor, + global_scales: Optional[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: Tuple): + """ + Arguments: + precision: the precision of the low-precision kernel, "fp8" for FP8, "nvfp4" for NVFP4. 
+ """ combined_hidden_states, event, hook = \ - self.buffer.low_latency_combine_fp4(hidden_states, global_scales, topk_idx, topk_weights, handle) + self.buffer.low_latency_combine_low_precision(precision, hidden_states, global_scales, topk_idx, topk_weights, handle) assert event.event is None assert hook is None diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 09edb4a2e11..6b1ca3cc45e 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -20,8 +20,8 @@ from .ops import MoEOp, MoEOpSelector from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, - FP8QDQFusedMoEMethod, MoEWeightLoadingMode, - NVFP4CutlassFusedMoEMethod, + FP8QDQFusedMoEMethod, FusedMoEQuantScalesW4A8, + MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod) from .routing import BaseMoeRoutingMethod @@ -191,13 +191,10 @@ def __init__( self.use_postquant_alltoall = False self.use_low_precision_combine = False if self.enable_alltoall: - qm = self.quant_config.quant_mode self.use_postquant_alltoall = (os.environ.get( - "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") - == "1") and qm.has_nvfp4() + "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1") self.use_low_precision_combine = (os.environ.get( - "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0") - == "1") and qm.has_nvfp4() + "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0") == "1") if self.alltoall_method_type == AlltoallMethodType.MNNVL: MnnvlMemory.initialize() @@ -319,6 +316,35 @@ def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens): return self.enable_alltoall + def deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe( + self, x: torch.Tensor, x_sf: Optional[torch.Tensor], + recv_expert_count: torch.Tensor, final_scales_dtype: torch.dtype + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, + torch.Tensor]: + # x shape: [#local experts, EP size * all_rank_max_num_tokens, hidden_size] + # recv_expert_count shape: [#local experts] + + # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP + # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API + mask = torch.arange(x.shape[1], + dtype=torch.int32, device=x.device).expand( + x.shape[0], + x.shape[1]) < recv_expert_count.unsqueeze(1) + token_selected_slots = torch.where( + mask, + torch.arange(x.shape[0] * self.mapping.moe_ep_rank, + x.shape[0] * (self.mapping.moe_ep_rank + 1), + dtype=torch.int32, + device=x.device).unsqueeze(1), self.num_slots) + x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) + if x_sf is not None: + x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], x_sf.shape[2]) + # Cheat the fused_moe API with fake top_k=1 + token_selected_slots = token_selected_slots.view(x.shape[0], 1) + token_final_scales = torch.ones_like(token_selected_slots, + dtype=final_scales_dtype) + return x, x_sf, token_selected_slots, token_final_scales + def _get_quant_method(self): if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -468,7 +494,7 @@ def forward_chunk( use_allgather = not use_all_to_all # If alltoall is disabled, we need also disable use_postquant_alltoall - use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all + use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all and self.has_any_quant # Prepare additional 
information for profiling in case padding is applied when using alltoall. # Only the non-alltoall case is considered for profiling in the warmup phase. @@ -518,28 +544,8 @@ def forward_chunk( assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens x, recv_expert_count, deep_ep_handle = \ self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots) - # x shape: [#local experts, EP size * all_rank_max_num_tokens, hidden_size] - # recv_expert_count shape: [#local experts] - - # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP - # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API - mask = torch.arange( - x.shape[1], dtype=torch.int32, device=x.device).expand( - x.shape[0], - x.shape[1]) < recv_expert_count.unsqueeze(1) - token_selected_slots = torch.where( - mask, - torch.arange( - x.shape[0] * self.mapping.moe_ep_rank, - x.shape[0] * (self.mapping.moe_ep_rank + 1), - dtype=torch.int32, - device=x.device).unsqueeze(1), self.num_slots) - x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) - # Cheat the fused_moe API with fake top_k=1 - token_selected_slots = token_selected_slots.view( - x.shape[0], 1) - token_final_scales = torch.ones_like( - token_selected_slots, dtype=token_final_scales.dtype) + x, _, token_selected_slots, token_final_scales = self.deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe( + x, None, recv_expert_count, token_final_scales.dtype) x_sf = None x_row = x.shape[0] @@ -621,41 +627,48 @@ def forward_chunk( if x_sf is not None: x_sf = x_sf.view(x_sf_dtype) elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: - token_num = x_row - hidden_size = x_col - assert x_sf is not None and self.has_nvfp4 - assert hidden_size % 32 == 0 - assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8 - assert x_sf.shape[0] == token_num and x_sf.shape[ - 1] == hidden_size // 16 - assert x.shape[0] == token_num and x.shape[1] == hidden_size // 2 - + assert self.has_any_quant, "DeepEPLowLatency postquant alltoall should have quantization" + assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens deep_ep_topk_idx = token_selected_slots deep_ep_topk_weights = token_final_scales - - assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens - x, x_sf, recv_expert_count, deep_ep_handle = \ - self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots) - assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8 - assert x.dim() == 3 and x_sf.dim() == 3 - assert x.shape[2] == hidden_size // 2 and x_sf.shape[ - 2] == hidden_size // 16 - - mask = torch.arange( - x.shape[1], dtype=torch.int32, device=x.device).expand( - x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1) - token_selected_slots = torch.where( - mask, - torch.arange(x.shape[0] * self.mapping.moe_ep_rank, - x.shape[0] * (self.mapping.moe_ep_rank + 1), - dtype=torch.int32, - device=x.device).unsqueeze(1), self.num_slots) - x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) - x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], - x_sf.shape[2]) - token_selected_slots = token_selected_slots.view(x.shape[0], 1) - token_final_scales = torch.ones_like( - token_selected_slots, dtype=token_final_scales.dtype) + if self.has_fp8_qdq: + assert x.dtype == torch.float8_e4m3fn and x_sf is None, "x should be torch.float8_e4m3fn and x_sf should be None in fp8 postquant alltoall" + x = x.view(torch.bfloat16) + x, recv_expert_count, deep_ep_handle = \ + 
self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots) + x = x.view(torch.float8_e4m3fn) + elif self.has_nvfp4: + token_num = x_row + hidden_size = x_col + assert x.dtype == torch.uint8 and x_sf is not None and x_sf.dtype == torch.uint8 + assert hidden_size % 32 == 0, "HiddenSize should be divisible by 32 in nvfp4 postquant alltoall" + assert x_sf.shape[0] == token_num and x_sf.shape[ + 1] == hidden_size // 16 + assert x.shape[0] == token_num and x.shape[ + 1] == hidden_size // 2 + + x, x_sf, recv_expert_count, deep_ep_handle = \ + self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots) + assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8 + assert x.dim() == 3 and x_sf.dim() == 3 + assert x.shape[2] == hidden_size // 2 and x_sf.shape[ + 2] == hidden_size // 16 + elif self.has_w4afp8: + assert isinstance(quant_scales, FusedMoEQuantScalesW4A8) + pre_quant_scales = quant_scales.pre_quant_scale_1 + assert pre_quant_scales.shape == ( + 1, x.shape[1]) and pre_quant_scales.dtype == x.dtype + x = (x * pre_quant_scales).to(torch.float8_e4m3fn).view( + torch.bfloat16) + x, recv_expert_count, deep_ep_handle = \ + self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots) + x = x.view(torch.float8_e4m3fn) + else: + raise ValueError( + f"unsupported quantization mode in postquant alltoall: {self.quant_config.quant_mode}" + ) + x, x_sf, token_selected_slots, token_final_scales = self.deep_ep_low_latency_dispatch_modify_output_to_adapt_fused_moe( + x, x_sf, recv_expert_count, token_final_scales.dtype) else: raise NotImplementedError( f"Not available alltoall method type: {self.alltoall_method_type!r}" @@ -704,11 +717,16 @@ def forward_chunk( self.expert_size_per_partition, num_tokens_per_expert_for_fused_moe, self.hidden_size) if self.use_low_precision_combine: - global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale( - final_hidden_states, recv_expert_count) - final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4( - final_hidden_states, global_scales, deep_ep_topk_idx, - deep_ep_topk_weights, deep_ep_handle) + assert self.has_nvfp4 or self.has_w4afp8 or self.has_fp8_qdq, "Low precision combine only supports nvfp4, w4afp8 and fp8 qdq" + precision = "fp8" + global_scales = None + if self.has_nvfp4: + precision = "nvfp4" + global_scales = torch.ops.trtllm.calculate_nvfp4_global_scale( + final_hidden_states, recv_expert_count) + final_hidden_states = self.deep_ep_buffer.low_latency_combine_low_precision( + precision, final_hidden_states, global_scales, + deep_ep_topk_idx, deep_ep_topk_weights, deep_ep_handle) else: final_hidden_states = self.deep_ep_buffer.low_latency_combine( final_hidden_states, deep_ep_topk_idx, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a0d23afde87..115bd2ce393 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1,12 +1,9 @@ import bisect import contextlib -import copy import functools import gc import inspect import math -import os -import traceback import weakref from abc import ABC, abstractmethod from contextlib import contextmanager @@ -17,16 +14,13 @@ import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, - str_dtype_to_torch, torch_dtype_to_str, - trace_func) + torch_dtype_to_str, trace_func) from 
tensorrt_llm.inputs.multimodal import (MultimodalParams, MultimodalRuntimeData) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraModelConfig from tensorrt_llm.mapping import CpType, Mapping -from tensorrt_llm.models.modeling_utils import QuantAlgo -from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 from ..attention_backend.interface import (AttentionMetadata, AttentionRuntimeFeatures) @@ -40,14 +34,11 @@ from ..distributed.communicator import init_pp_comm from ..expert_statistic import ExpertStatistic from ..metadata import KVCacheParams -from ..model_config import ModelConfig, MoeLoadBalancerConfig -from ..models import AutoModelForCausalLM from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids -from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, - timing) -from ..modules.fused_moe.moe_load_balancer import ( - MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) +from ..models.modeling_utils import DecoderModelForCausalLM +from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, + MoeLoadBalancerIterContext) from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) @@ -56,12 +47,13 @@ from ..utils import (get_model_extra_attrs, set_per_request_piecewise_cuda_graph_flag, set_torch_compiling, with_model_extra_attrs) -from .config import LoadFormat, PyTorchConfig +from .config import PyTorchConfig from .config_utils import is_mla from .cuda_graph_runner import CUDAGraphRunner from .guided_decoder import CapturableGuidedDecoder from .layerwise_nvtx_marker import LayerwiseNvtxMarker from .llm_request import get_draft_token_length +from .model_loader import ModelLoader from .resource_manager import (BaseResourceManager, KVCacheManager, ResourceManager, ResourceManagerType) from .sampler import SampleStateTensors @@ -96,137 +88,6 @@ def warmup(self, resource_manager: ResourceManager) -> None: return -_KV_CACHE_MAP = { - "fp8": QuantAlgo.FP8.value, - "nvfp4": QuantAlgo.NVFP4.value, - "auto": "auto" -} -_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") - - -def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, - mamba_ssm_cache_dtype: str) -> None: - if mamba_ssm_cache_dtype == "auto": - mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype - else: - mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) - - config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype - - -def validate_and_set_kv_cache_quant(model_config: ModelConfig, - pyt_kv_cache_dtype: str) -> QuantAlgo: - logger.info( - f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' - ) - # Quantization from hf_quant_config.json - kv_cache_quant = model_config.quant_config.kv_cache_quant_algo - # PyTorch configuration quantization - valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) - mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) - - # If we're letting the checkpoint dictate the quant with auto, simply - # return and do not modify the checkpoint. - if pyt_kv_cache_dtype == "auto": - logger.info( - f'KV cache quantization set to "{pyt_kv_cache_dtype}". Using ' - "checkpoint KV quantization.") - return - - # If we have an invalid quantization, simply raise an exception. 
- if not valid_pyt_quant: - raise ValueError( - "Overriding KV cache quantization with an invalid type " - f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' - f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') - - # If we get to this point we have a valid quantization setting, but if - # we have an existing setting and it doesn't match we shouldn't proceed. - if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: - raise RuntimeError( - "Attempting to override KV cache quantization " - f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' - f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' - "pre-quantized KV cache that doesn't match.") - - # We have an open ended KV cache in the checkpoint - # and we have a specified override. - model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant - - -def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, - seed: int = 0, -) -> None: - """ - This is similar to this function in SGLang with a few changes: - https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 - - This method is used to initialize weights with dummy values for testing - models without checkpoints. Unquantized (FP16/BF16/etc) values are generated - from a uniform distribution over the interval (low, high). - - For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. - We simply generate values uniformly across an interval that has been empirically verified - to not generate NaNs/inf for these. - """ - - def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: - # These values are not necessarily the largest possible min/max, - # they need to be small enough to avoid NaNs. - if dtype in (torch.float8_e4m3fn, torch.int8): - return (-3.0, 3.0) - - elif dtype == float4_e2m1x2: - # These correspond to bits of 2 packed FP4 values. - # Because we only go up to 64, the high 4 bits will - # always be 0. But this is fine - we just need values - # that won't generate NaNs. - return (0, 64) - - else: - raise NotImplementedError(f"Unknown quantized type: {dtype}.") - - for param in model.state_dict().values(): - generator = torch.Generator(device=param.data.device) - generator.manual_seed(seed) - dtype = param.data.dtype - - if param.data.element_size() < 2: - # We need to do a cast/round since torch doesn't have uniform_ - # support for these dtypes. - tmp_param = torch.empty(param.data.shape, - dtype=torch.float16, - device=param.data.device) - - quant_min, quant_max = _get_random_min_max(dtype) - tmp_param = tmp_param.uniform_(quant_min, - quant_max, - generator=generator) - - param.data.copy_(tmp_param.to(dtype)) - - # Note: no need to to mess with int32 params, these are probably - # constants and not weights. 
- elif torch.is_floating_point(param): - param.uniform_(low, high, generator=generator) - - -def get_rank_model_storage(model): - total_bytes = 0 - for _, param in model.named_parameters(): - if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( - ): - total_bytes += param.element_size() * param.nelement() - for _, buf in model.named_buffers(): - if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( - ): - total_bytes += buf.element_size() * buf.nelement() - return total_bytes - - def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], max_batch_size: int, max_num_tokens: int, max_draft_len: int, @@ -280,6 +141,7 @@ def __init__( is_draft_model: bool = False, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, + model: Optional[torch.nn.Module] = None, ): self.ub_buffers = None self.batch_size = batch_size @@ -302,21 +164,26 @@ def __init__( self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( ) - attn_backend = pytorch_backend_config.attn_backend - self.model = self._load_model( - model_path, - mapping=self.mapping, - checkpoint_loader=checkpoint_loader, - attn_backend=attn_backend, - moe_backend=pytorch_backend_config.moe_backend, - moe_disable_finalize_fusion=pytorch_backend_config. - moe_disable_finalize_fusion, - load_format=pytorch_backend_config.load_format, - max_num_tokens=max_num_tokens, - moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens, - moe_load_balancer=pytorch_backend_config.moe_load_balancer, - lora_config=lora_config, - drafting_loop_wrapper=drafting_loop_wrapper) + if model is None: + loader = ModelLoader( + pytorch_backend_config=pytorch_backend_config, + mapping=self.mapping, + spec_config=self.spec_config, + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + lora_config=lora_config, + ) + self.model, moe_load_balancer = loader.load( + checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) + if isinstance(moe_load_balancer, MoeLoadBalancer): + setattr(self, "moe_load_balancer", moe_load_balancer) + else: + self.model = model + if drafting_loop_wrapper is not None: + self.model = drafting_loop_wrapper(self.model) + self.model_is_wrapped = True + else: + self.model_is_wrapped = False # In case that some tests use stub models and override `_load_model`. if not hasattr(self.model, 'extra_attrs'): self.model.extra_attrs = {} @@ -387,7 +254,8 @@ def __init__( self.is_warmup = False - self.attn_backend = get_attention_backend(attn_backend) + self.attn_backend = get_attention_backend( + pytorch_backend_config.attn_backend) if self.is_spec_decode: self.spec_metadata = None @@ -940,145 +808,6 @@ def __del__(self) -> None: # Release model weights. release_gc() - def _load_model(self, - checkpoint_dir: str, - checkpoint_loader: BaseCheckpointLoader, - load_format: LoadFormat, - max_num_tokens: int, - moe_max_num_tokens: Optional[int] = None, - moe_load_balancer: Optional[MoeLoadBalancerConfig] = None, - lora_config: Optional[LoraConfig] = None, - drafting_loop_wrapper: Optional[Callable[ - [torch.nn.Module], torch.nn.Module]] = None, - **kwargs) -> DecoderModelForCausalLM: - config = checkpoint_loader.load_config( - checkpoint_dir, - trust_remote_code=True, - enable_min_latency=self.pytorch_backend_config.enable_min_latency, - use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, - force_dynamic_quantization=self.pytorch_backend_config. 
- force_dynamic_quantization, - spec_config=self.spec_config, - max_num_tokens=max_num_tokens, - max_seq_len=self.max_seq_len, - moe_max_num_tokens=moe_max_num_tokens, - moe_load_balancer=moe_load_balancer, - lora_config=lora_config, - allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, - mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, - **kwargs) - - validate_and_set_kv_cache_quant( - config, self.pytorch_backend_config.kv_cache_dtype) - validate_and_set_mamba_ssm_cache_dtype( - config, self.pytorch_backend_config.mamba_ssm_cache_dtype) - - num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) - if num_layers > 0: - config.pretrained_config.num_hidden_layers = num_layers - for sub_config in ["text_config", "vision_config"]: - if hasattr(config.pretrained_config, sub_config): - getattr(config.pretrained_config, - sub_config).num_hidden_layers = num_layers - - with timing("Model init total"), maybe_create_moe_load_balancer( - config, self.mapping) as moe_load_balancer: - - try: - # config will be modified in-place for some models, like Qwen2 - config_copy = copy.deepcopy(config) - with MetaInitMode(): - model = AutoModelForCausalLM.from_config(config_copy) - - memo = dict() - - def init_meta_tensor(t: torch.Tensor): - if t.device != torch.device('meta'): - return t - if t not in memo: - memo[t] = torch.empty_like(t, device='cuda') - return memo[t] - - model._apply(init_meta_tensor) - config = config_copy - - except Exception: - logger.info( - f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n" - ) - model = AutoModelForCausalLM.from_config(config) - - model.to("cuda") - rank_model_storage = get_rank_model_storage(model) - logger.info( - f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." - ) - if load_format == LoadFormat.AUTO: - if hasattr(model, 'llm_checkpoint_dir'): - weights = checkpoint_loader.load_weights( - model.llm_checkpoint_dir) - else: - weights = checkpoint_loader.load_weights(checkpoint_dir) - - weight_mapper = checkpoint_loader.get_initialized_weight_mapper( - model, config) - self._call_load_weights(model.load_weights, weights, - weight_mapper) - - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - weights = checkpoint_loader.load_weights( - self.spec_config.speculative_model_dir) - self._call_load_weights(model.load_draft_weights, weights, - weight_mapper) - - elif load_format == LoadFormat.DUMMY: - initialize_dummy_weights(model) - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - model.draft_model.load_weights_from_target_model(model) - - elif load_format == LoadFormat.VISION_ONLY: - # Vision weights are already loaded within the model. - logger.info( - "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." 
- ) - - else: - raise NotImplementedError( - f"No load support for load format: {load_format}") - - for module in model.modules(): - if hasattr(module, 'post_load_weights'): - module.post_load_weights() - - if isinstance(moe_load_balancer, MoeLoadBalancer): - setattr(self, "moe_load_balancer", moe_load_balancer) - moe_load_balancer.register_weight_slots_after_to_cuda() - logger.info("moe_load_balancer finalizing model...") - moe_load_balancer.finalize_model() - logger.info("moe_load_balancer finalize model done") - - torch.cuda.current_stream().synchronize() - - if drafting_loop_wrapper is not None: - model = drafting_loop_wrapper(model) - self.model_is_wrapped = True - else: - self.model_is_wrapped = False - - return model - - def _call_load_weights(self, load_method, weights, weight_mapper): - # TODO smor- this is a temporary solution to load weights. - # Once checkpoint format is unified, this method will be removed. - from inspect import getfullargspec - args = getfullargspec(load_method).args - if "weight_mapper" in args: - load_method(weights, weight_mapper=weight_mapper) - else: - load_method(weights) - def _init_max_seq_len(self): # For mm_encoder_only mode, infer_max_seq_len() is for LLM decoder models if hasattr(self.model, 'infer_max_seq_len'): diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py new file mode 100644 index 00000000000..eb3618dabb4 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -0,0 +1,330 @@ +import copy +import inspect +import os +import traceback +from typing import Callable, Optional, Tuple + +import torch + +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantAlgo +from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 + +from ..model_config import ModelConfig +from ..models import AutoModelForCausalLM +from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader +from ..models.modeling_utils import MetaInitMode, timing +from ..modules.fused_moe.moe_load_balancer import ( + MoeLoadBalancer, maybe_create_moe_load_balancer) +from .config import LoadFormat, PyTorchConfig + +_KV_CACHE_MAP = { + "fp8": QuantAlgo.FP8.value, + "nvfp4": QuantAlgo.NVFP4.value, + "auto": "auto" +} +_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") + + +def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, + mamba_ssm_cache_dtype: str) -> None: + if mamba_ssm_cache_dtype == "auto": + mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype + else: + mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) + + config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype + + +def validate_and_set_kv_cache_quant(model_config: ModelConfig, + pyt_kv_cache_dtype: str) -> QuantAlgo: + logger.info( + f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' + ) + # Quantization from hf_quant_config.json + kv_cache_quant = model_config.quant_config.kv_cache_quant_algo + # PyTorch configuration quantization + valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) + mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) + + # If we're letting the checkpoint dictate the quant with auto, simply + # return and do not modify the checkpoint. + if pyt_kv_cache_dtype == "auto": + logger.info( + f'KV cache quantization set to "{pyt_kv_cache_dtype}". 
Using ' + "checkpoint KV quantization.") + return + + # If we have an invalid quantization, simply raise an exception. + if not valid_pyt_quant: + raise ValueError( + "Overriding KV cache quantization with an invalid type " + f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' + f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') + + # If we get to this point we have a valid quantization setting, but if + # we have an existing setting and it doesn't match we shouldn't proceed. + if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: + raise RuntimeError( + "Attempting to override KV cache quantization " + f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' + f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' + "pre-quantized KV cache that doesn't match.") + + # We have an open ended KV cache in the checkpoint + # and we have a specified override. + model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 0, +) -> None: + """ + This is similar to this function in SGLang with a few changes: + https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 + This method is used to initialize weights with dummy values for testing + models without checkpoints. Unquantized (FP16/BF16/etc) values are generated + from a uniform distribution over the interval (low, high). + For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. + We simply generate values uniformly across an interval that has been empirically verified + to not generate NaNs/inf for these. + """ + + def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: + # These values are not necessarily the largest possible min/max, + # they need to be small enough to avoid NaNs. + if dtype in (torch.float8_e4m3fn, torch.int8): + return (-3.0, 3.0) + + elif dtype == float4_e2m1x2: + # These correspond to bits of 2 packed FP4 values. + # Because we only go up to 64, the high 4 bits will + # always be 0. But this is fine - we just need values + # that won't generate NaNs. + return (0, 64) + + else: + raise NotImplementedError(f"Unknown quantized type: {dtype}.") + + for param in model.state_dict().values(): + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + dtype = param.data.dtype + + if param.data.element_size() < 2: + # We need to do a cast/round since torch doesn't have uniform_ + # support for these dtypes. + tmp_param = torch.empty(param.data.shape, + dtype=torch.float16, + device=param.data.device) + + quant_min, quant_max = _get_random_min_max(dtype) + tmp_param = tmp_param.uniform_(quant_min, + quant_max, + generator=generator) + + param.data.copy_(tmp_param.to(dtype)) + + # Note: no need to to mess with int32 params, these are probably + # constants and not weights. 
+ elif torch.is_floating_point(param): + param.uniform_(low, high, generator=generator) + + +def get_rank_model_storage(model): + total_bytes = 0 + for _, param in model.named_parameters(): + if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( + ): + total_bytes += param.element_size() * param.nelement() + for _, buf in model.named_buffers(): + if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( + ): + total_bytes += buf.element_size() * buf.nelement() + return total_bytes + + +class ModelLoader: + """ + Handles the loading, configuration, and weight initialization of a PyTorch model. + This class isolates model loading logic from the main execution engine. + """ + + def __init__(self, + pytorch_backend_config: PyTorchConfig, + mapping: Mapping, + spec_config: Optional["DecodingBaseConfig"], + max_num_tokens: int, + max_seq_len: Optional[int], + lora_config: Optional[LoraConfig] = None): + """ + Initializes the ModelLoader. + + Args: + pytorch_backend_config: Configuration for the PyTorch backend. + mapping: The distributed mapping configuration. + spec_config: Configuration for speculative decoding. + max_num_tokens: The maximum number of tokens the engine will handle. + max_seq_len: The maximum sequence length. + lora_config: Configuration for LoRA. + """ + self.pytorch_backend_config = pytorch_backend_config + self.mapping = mapping + self.spec_config = spec_config + self.max_num_tokens = max_num_tokens + self.max_seq_len = max_seq_len + self.lora_config = lora_config + + def load( + self, + checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader, + ): + """ + Loads the model, its weights, and applies necessary configurations. + + Args: + checkpoint_dir: The directory of the model checkpoint. + checkpoint_loader: The loader object for model checkpoints. + + Returns: + The loaded and initialized PyTorch model. + """ + config = self._load_and_validate_config(checkpoint_dir, + checkpoint_loader) + load_format = self.pytorch_backend_config.load_format + + with timing("Model init total"), maybe_create_moe_load_balancer( + config, self.mapping) as moe_load_balancer: + try: + # config will be modified in-place for some models, like Qwen2 + config_copy = copy.deepcopy(config) + with MetaInitMode(): + model = AutoModelForCausalLM.from_config(config_copy) + + memo = dict() + + def init_meta_tensor(t: torch.Tensor): + if t.device != torch.device('meta'): + return t + if t not in memo: + memo[t] = torch.empty_like(t, device='cuda') + return memo[t] + + model._apply(init_meta_tensor) + config = config_copy + + except Exception: + logger.info( + f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n" + ) + model = AutoModelForCausalLM.from_config(config) + + model.to("cuda") + rank_model_storage = get_rank_model_storage(model) + logger.info( + f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." 
+ ) + if load_format == LoadFormat.AUTO: + if hasattr(model, 'llm_checkpoint_dir'): + weights = checkpoint_loader.load_weights( + model.llm_checkpoint_dir) + else: + weights = checkpoint_loader.load_weights(checkpoint_dir) + + weight_mapper = checkpoint_loader.get_initialized_weight_mapper( + model, config) + self._call_load_weights(model.load_weights, weights, + weight_mapper) + + if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + weights = checkpoint_loader.load_weights( + self.spec_config.speculative_model_dir) + self._call_load_weights(model.load_draft_weights, weights, + weight_mapper) + + elif load_format == LoadFormat.DUMMY: + initialize_dummy_weights(model) + if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + model.draft_model.load_weights_from_target_model(model) + + elif load_format == LoadFormat.VISION_ONLY: + # Vision weights are already loaded within the model. + logger.info( + "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." + ) + + else: + raise NotImplementedError( + f"No load support for load format: {load_format}") + + for module in model.modules(): + if hasattr(module, 'post_load_weights'): + module.post_load_weights() + + if isinstance(moe_load_balancer, MoeLoadBalancer): + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") + + torch.cuda.current_stream().synchronize() + + return model, moe_load_balancer + + def _load_and_validate_config( + self, checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader) -> ModelConfig: + """Loads and validates the model configuration.""" + config = checkpoint_loader.load_config( + checkpoint_dir, + trust_remote_code=True, + mapping=self.mapping, + enable_min_latency=self.pytorch_backend_config.enable_min_latency, + use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, + force_dynamic_quantization=self.pytorch_backend_config. + force_dynamic_quantization, + spec_config=self.spec_config, + max_num_tokens=self.max_num_tokens, + max_seq_len=self.max_seq_len, + moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens, + moe_load_balancer=self.pytorch_backend_config.moe_load_balancer, + lora_config=self.lora_config, + allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, + mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, + attn_backend=self.pytorch_backend_config.attn_backend, + moe_backend=self.pytorch_backend_config.moe_backend, + moe_disable_finalize_fusion=self.pytorch_backend_config. 
+ moe_disable_finalize_fusion) + + validate_and_set_kv_cache_quant( + config, self.pytorch_backend_config.kv_cache_dtype) + validate_and_set_mamba_ssm_cache_dtype( + config, self.pytorch_backend_config.mamba_ssm_cache_dtype) + + # Allow overriding the number of layers via environment variable + num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", + "0")) + if num_layers_override > 0: + config.pretrained_config.num_hidden_layers = num_layers_override + for sub_config in ["text_config", "vision_config"]: + if hasattr(config.pretrained_config, sub_config): + getattr(config.pretrained_config, + sub_config).num_hidden_layers = num_layers_override + return config + + def _call_load_weights(self, load_method: Callable, weights, weight_mapper): + """Calls the model's weight loading method with the correct arguments.""" + args = inspect.getfullargspec(load_method).args + if "weight_mapper" in args: + load_method(weights, weight_mapper=weight_mapper) + else: + load_method(weights) diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py index ff3cd933ce1..a21511f38cd 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -8,7 +8,7 @@ import yaml -from tensorrt_llm._torch.pyexecutor.model_engine import \ +from tensorrt_llm._torch.pyexecutor.model_loader import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings, get_model_config) diff --git a/tensorrt_llm/bench/dataclasses/reporting.py b/tensorrt_llm/bench/dataclasses/reporting.py index b12873b5637..70e4cae646b 100755 --- a/tensorrt_llm/bench/dataclasses/reporting.py +++ b/tensorrt_llm/bench/dataclasses/reporting.py @@ -4,7 +4,7 @@ from collections import defaultdict from typing import Any, Dict, List, NamedTuple -from tensorrt_llm._torch.pyexecutor.model_engine import \ +from tensorrt_llm._torch.pyexecutor.model_loader import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig from tensorrt_llm.bench.dataclasses.general import DatasetMetadata diff --git a/tests/integration/test_lists/qa/llm_perf_cluster_nim.yml b/tests/integration/test_lists/qa/llm_perf_cluster_nim.yml new file mode 100644 index 00000000000..e56252fd7e3 --- /dev/null +++ b/tests/integration/test_lists/qa/llm_perf_cluster_nim.yml @@ -0,0 +1,141 @@ +version: 0.0.1 +llm_perf_cluster_nim: +- condition: + ranges: + system_gpu_count: + gte: 1 + tests: + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-pytorch-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:128,128-quant:fp8] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:512,32-quant:fp8] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-streaming-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:500,2000] + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:1000,1000-quant:fp8] + - 
perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:500,2000-quant:fp8] + - perf/test_perf.py::test_perf[deepseek_v3_lite_nvfp4-bench-pytorch-streaming-float4-maxbs:2048-maxnt:8192-input_output_len:256,256-reqs:200] + # for chunked prefill cases + - perf/test_perf.py::test_perf[deepseek_v3_lite_nvfp4-bench-pytorch-float4-maxbs:512-maxnt:2048-kv_frac:0.85-input_output_len:5000,500-reqs:200] + # Phi-4-multimodal-instruct + - perf/test_perf.py::test_perf[phi_4_multimodal_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-con:250] + - perf/test_perf.py::test_perf[phi_4_multimodal_instruct-bench-pytorch-bfloat16-input_output_len:1000,1000-con:250] + - perf/test_perf.py::test_perf[phi_4_multimodal_instruct-bench-pytorch-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[phi_4_multimodal_instruct-bench-pytorch-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[phi_4_multimodal_instruct_image-bench-pytorch-bfloat16-input_output_len:1000,1000-loras:1-con:250] + - perf/test_perf.py::test_perf[phi_4_multimodal_instruct_audio-bench-pytorch-bfloat16-input_output_len:1000,1000-loras:1-con:250] + #Mistral-Small-3.1-24B-Instruct-2503 + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-input_output_len:1000,2000-reqs:8-con:1] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-input_output_len:1000,2000-reqs:500-con:200] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1] TIMEOUT(120) + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:300-con:200] TIMEOUT(120) + + +- condition: + ranges: + system_gpu_count: + gte: 2 + tests: + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:nvfp4-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-maxbs:256-input_output_len:128,128-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-float8-maxbs:256-input_output_len:512,32-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct_fp8-bench-pytorch-float8-maxbs:256-input_output_len:128,128-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-maxbs:256-input_output_len:512,32-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-streaming-float8-maxbs:256-input_output_len:512,32-gpus:2] + - perf/test_perf.py::test_perf[llama_v2_13b-bench-float16-input_output_len:128,128-loras:8-gpus:2] + #Mistral-Small-3.1-24B-Instruct-2503 + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-input_output_len:1000,2000-reqs:8-con:1-gpus:2] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-input_output_len:1000,2000-reqs:500-con:200-gpus:2] + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:1-maxnt:20000-input_output_len:20000,2000-reqs:8-con:1-gpus:2] TIMEOUT(120) + - perf/test_perf.py::test_perf[mistral_small_v3.1_24b-bench-pytorch-bfloat16-maxbs:4096-maxnt:20000-input_output_len:20000,2000-reqs:300-con:200-gpus:2] TIMEOUT(120) + +# Tests for systems with 4+ GPUs +- condition: + ranges: + system_gpu_count: + gte: 4 + tests: + - perf/test_perf.py::test_perf[starcoder_15b-bench-float16-input_output_len:512,200-gpus:4] + - 
perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:512-input_output_len:128,128-ep:4-tp:4-gpus:4] + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-streaming-float4-maxbs:512-input_output_len:128,128-ep:4-tp:4-gpus:4] + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,1000-reqs:2000-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,2000-reqs:3000-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:1000-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:20000-ep:4-tp:4-gpus:4] TIMEOUT(120) + # for chunked prefill cases + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:512-maxnt:2048-kv_frac:0.85-input_output_len:5000,500-reqs:200-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:256-maxnt:1024-kv_frac:0.85-input_output_len:2000,2000-reqs:200-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:1000-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:2000-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[qwen3_235b_a22b_fp4-bench-pytorch-float4-input_output_len:1000,2000-con:512-ep:4-gpus:4] + #llama_v3.1_405b_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,1000-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1024,2048-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-gpus:4] TIMEOUT(120) + #llama_v3.3_70b_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-input_output_len:512,32-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:1000-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:200-gpus:4] TIMEOUT(120) + #llama_v4_scout_17b_16e_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-reqs:500-gpus:4] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:500-gpus:4] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:200-gpus:4] TIMEOUT(120) + + +# Tests for systems with 8+ GPUs +- condition: + ranges: + system_gpu_count: + gte: 8 + tests: + #llama_v3.3_nemotron_super_49b + - 
perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-input_output_len:500,2000-quant:fp8-con:250-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-bfloat16-input_output_len:500,2000-con:250-gpus:8] + #llama_v3.3_70b_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-reqs:3000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-tp:8-gpus:8] + + #llama_v4_scout_17b_16e_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:128,128-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:512,32-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:8-input_output_len:128,128-reqs:80-gpus:8] + - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:8-input_output_len:512,32-reqs:80-gpus:8] + #deepseek_r1_fp8 + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:512-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1-input_output_len:1000,2000-reqs:10-ep:4-tp:8-gpus:8] #min latency test + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:384-maxnt:1536-input_output_len:1000,2000-reqs:49152-con:3072-ep:8-tp:8-gpus:8] #max throughput test + #deepseek_r1_nvfp4 + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:512-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:1-input_output_len:1000,2000-reqs:10-ep:4-tp:8-gpus:8] #min latency test + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-streaming-float4-maxbs:1-input_output_len:1000,2000-reqs:10-ep:4-tp:8-gpus:8] #min latency test + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:384-maxnt:1536-input_output_len:1000,2000-reqs:49152-con:3072-ep:8-tp:8-gpus:8] TIMEOUT (120) #max throughput test + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-streaming-float4-maxbs:384-maxnt:1536-input_output_len:1000,2000-reqs:49152-con:3072-ep:8-tp:8-gpus:8] #max throughput test + # for chunked prefill cases + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:512-maxnt:2048-kv_frac:0.85-input_output_len:5000,500-reqs:200-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:256-maxnt:1024-kv_frac:0.85-input_output_len:2000,2000-reqs:200-ep:8-tp:8-gpus:8] TIMEOUT(120) + #deepseek_r1_0528_fp4 + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,1000-reqs:20000-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,2000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(120) + - 
perf/test_perf.py::test_perf[deepseek_r1_0528_fp4-bench-pytorch-float4-maxbs:1000-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:20000-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-bfloat16-input_output_len:2000,500-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-streaming-bfloat16-input_output_len:2000,500-ep:8-tp:8-gpus:8] TIMEOUT (40) + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:500,2000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:2000,500-ep:8-tp:8-gpus:8] TIMEOUT (40) + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-streaming-float8-input_output_len:2000,500-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[qwen3_235b_a22b_fp4-bench-pytorch-float4-input_output_len:1000,2000-con:8-ep:8-tp:8-gpus:8] + #gpt_oss_120b + # max throughput test + - perf/test_perf.py::test_perf[gpt_oss_120b_fp4-bench-pytorch-float4-maxbs:720-maxnt:16384-input_output_len:1024,1024-reqs:1280-con:256-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[gpt_oss_120b_fp4-bench-pytorch-float4-maxbs:720-maxnt:16384-input_output_len:1024,1024-reqs:2560-con:512-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[gpt_oss_120b_fp4-bench-pytorch-float4-maxbs:720-maxnt:16384-input_output_len:1024,1024-reqs:5120-con:1024-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[gpt_oss_120b_fp4-bench-pytorch-float4-maxbs:720-maxnt:16384-input_output_len:1024,1024-reqs:20480-con:4096-ep:8-tp:8-gpus:8] TIMEOUT(180) + # min latency test + - perf/test_perf.py::test_perf[gpt_oss_120b_fp4-bench-pytorch-float4-maxbs:720-maxnt:16384-input_output_len:1024,1024-reqs:8-con:1-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[gpt_oss_120b_fp4-bench-pytorch-float4-maxbs:720-maxnt:16384-input_output_len:1024,1024-reqs:100-con:32-ep:8-tp:8-gpus:8] diff --git a/tests/integration/test_lists/qa/llm_perf_nim.yml b/tests/integration/test_lists/qa/llm_perf_nim.yml index 9b436e58136..a0b8b40302b 100644 --- a/tests/integration/test_lists/qa/llm_perf_nim.yml +++ b/tests/integration/test_lists/qa/llm_perf_nim.yml @@ -150,7 +150,6 @@ llm_perf_nim: - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-gpus:2] #t5 - perf/test_perf.py::test_perf[t5-bench-float16-input_output_len:128,20-gpus:2] - - perf/test_perf.py::test_perf[flan_t5_large-bench-float16-input_output_len:128,20-gpus:2] - condition: ranges: @@ -168,7 +167,6 @@ llm_perf_nim: #llama_v3.1_70b #trt backend - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:1024,1024-tp:2-gpus:2] - - perf/test_perf.py::test_perf[llama_70b_sq_per_tensor-cppmanager-exe-plugin_ifb-float16-input_output_len:128,128+512,32-gpus:2] #mixtral_8x7b_v0.1 #trt backend - perf/test_perf.py::test_perf[mixtral_8x7b-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,128+512,32-gpus:2] @@ -199,6 +197,8 @@ llm_perf_nim: #trt backend - 
perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:1000,1000-quant:fp8-tp:2] - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:500,2000-quant:fp8-tp:2] + # torch backend + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-pytorch-float16-input_output_len:128,128] #phi_3_mini_128k_instruct #trt backend - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:1000,1000-quant:fp8-tp:2] diff --git a/tests/unittest/_torch/executor/test_pytorch_model_engine.py b/tests/unittest/_torch/executor/test_pytorch_model_engine.py index 8a06a3a9f0f..ec53a1ae832 100644 --- a/tests/unittest/_torch/executor/test_pytorch_model_engine.py +++ b/tests/unittest/_torch/executor/test_pytorch_model_engine.py @@ -67,16 +67,14 @@ def __init__(self, mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), tp_size=tensorrt_llm.mpi_world_size(), rank=tensorrt_llm.mpi_rank()) - self.model_is_wrapped = False + model = DummyModel(self.dtype) super().__init__(model_path="", pytorch_backend_config=pytorch_backend_config, checkpoint_loader=None, batch_size=batch_size, max_seq_len=max_seq_len, - mapping=mapping) - - def _load_model(self, mode_path: str, **kwargs) -> torch.nn.Module: - return DummyModel(self.dtype) + mapping=mapping, + model=model) def _create_request(num_tokens, req_id: int):
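Since this patch moves model construction out of `PyTorchModelEngine._load_model` and into the new standalone `ModelLoader`, a short usage sketch may help reviewers see the intended call pattern. This is a minimal illustration based only on the signatures visible in this diff: the wrapper function `build_model` is hypothetical, and speculative-decoding and LoRA configuration are deliberately omitted here.

```python
from typing import Optional, Tuple

import torch

from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
    BaseCheckpointLoader
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader
from tensorrt_llm.mapping import Mapping


def build_model(checkpoint_dir: str,
                checkpoint_loader: BaseCheckpointLoader,
                pytorch_backend_config: PyTorchConfig,
                mapping: Mapping,
                max_num_tokens: int,
                max_seq_len: Optional[int] = None
                ) -> Tuple[torch.nn.Module, object]:
    """Sketch of loading a model the way PyTorchModelEngine now does when no
    pre-built `model` is passed in (spec_config and lora_config omitted)."""
    loader = ModelLoader(
        pytorch_backend_config=pytorch_backend_config,
        mapping=mapping,
        spec_config=None,  # no speculative decoding in this sketch
        max_num_tokens=max_num_tokens,
        max_seq_len=max_seq_len,
    )
    # `load` returns the initialized model together with the object produced
    # by `maybe_create_moe_load_balancer`; the engine keeps the latter only
    # when it is an actual MoeLoadBalancer instance.
    model, moe_load_balancer = loader.load(
        checkpoint_dir=checkpoint_dir, checkpoint_loader=checkpoint_loader)
    return model, moe_load_balancer
```

The corresponding engine-side change is the new optional `model` argument to `PyTorchModelEngine.__init__`: tests can now inject a stub model directly (as `test_pytorch_model_engine.py` does above) instead of overriding `_load_model`.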