diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6928347baa..7de6ddb0d0 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -98,6 +98,7 @@ def cross_compile_for_windows(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -173,6 +174,7 @@ def cross_compile_for_windows(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. Enable this when DTensors or other distributed tensors are present in the model.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -332,6 +334,7 @@ def cross_compile_for_windows(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     # disable the following settings is not supported for cross compilation for windows feature
@@ -421,6 +424,7 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -498,6 +502,7 @@ def compile(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. Enable this when DTensors or other distributed tensors are present in the model.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -674,6 +679,7 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     settings = CompilationSettings(**compilation_options)
@@ -964,6 +970,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1029,6 +1036,7 @@ def convert_exported_program_to_serialized_trt_engine(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. Enable this when DTensors or other distributed tensors are present in the model.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -1147,6 +1155,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index bcb8495c67..46db1c0d3b 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1002,66 +1002,3 @@ def args_bounds_check(
     args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
 ) -> Any:
     return args[i] if len(args) > i and args[i] is not None else replacement
-
-
-def load_tensorrt_llm() -> bool:
-    """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-
-    Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
-    """
-    try:
-        import tensorrt_llm as trt_llm  # noqa: F401
-
-        _LOGGER.info("TensorRT-LLM successfully imported")
-        return True
-    except (ImportError, AssertionError) as e_import_error:
-        # Check for environment variable for the plugin library path
-        plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-        if not plugin_lib_path:
-            _LOGGER.warning(
-                "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
-            )
-            return False
-
-        _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
-        try:
-            # Load the shared library
-            handle = ctypes.CDLL(plugin_lib_path)
-            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-        except OSError as e_os_error:
-            _LOGGER.error(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
-                exc_info=e_os_error,
-            )
-            return False
-
-        try:
-            # Configure plugin initialization arguments
-            handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-            handle.initTrtLlmPlugins.restype = ctypes.c_bool
-        except AttributeError as e_plugin_unavailable:
-            _LOGGER.warning(
-                "Unable to initialize the TensorRT-LLM plugin library",
-                exc_info=e_plugin_unavailable,
-            )
-            return False
-
-        try:
-            # Initialize the plugin
-            TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-            if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
-                return True
-            else:
-                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
-                return False
-        except Exception as e_initialization_error:
-            _LOGGER.warning(
-                "Exception occurred during TensorRT-LLM plugin library initialization",
-                exc_info=e_initialization_error,
-            )
-            return False
-        return False
diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
index 79611c7552..3e67457e54 100644
--- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -11,11 +11,11 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
+from torch_tensorrt.dynamo.utils import load_tensorrt_llm

 _LOGGER: logging.Logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index e4018ae95c..a8b1b10e3a 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -1,11 +1,18 @@
 from __future__ import annotations

+import ctypes
 import gc
 import logging
+import os
+import shutil
+import subprocess
+import sys
+import urllib.request
 import warnings
 from dataclasses import fields, replace
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from urllib.error import URLError

 import numpy as np
 import sympy
@@ -14,7 +21,7 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import dtype
+from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
@@ -812,3 +819,126 @@ def is_tegra_platform() -> bool:
     if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
         return True
     return False
+
+
+def download_plugin_lib_path(py_version: str, platform: str) -> str:
+    plugin_lib_path = None
+
+    # Download the TRT-LLM wheel if it is not already present
+    # TODO: check how to fix the 0.18.0 hardcode below
+    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+    file_name = f"tensorrt_llm-0.18.0-{py_version}-{py_version}-{platform}.whl"
+    download_url = base_url + file_name
+    if not os.path.exists(file_name):
+        try:
+            logger.debug(f"Downloading {download_url} ...")
+            urllib.request.urlretrieve(download_url, file_name)
+            logger.debug("Download succeeded and TRT-LLM wheel is now present")
+        except urllib.error.HTTPError as e:
+            logger.error(
+                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
+            )
+        except urllib.error.URLError as e:
+            logger.error(
+                f"URL error when trying to download {download_url}: {e.reason}"
+            )
+        except OSError as e:
+            logger.error(f"Local file write error: {e}")
+
+    # Extract the plugin library from the wheel; the library will already be
+    # present if the wheel was unpacked on a previous run
+    if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
+        plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
+    else:
+        try:
+            import zipfile
+        except ImportError:
+            raise ImportError(
+                "The zipfile module is required to unpack the TRT-LLM wheel but could not be imported"
+            )
+        with zipfile.ZipFile(file_name, "r") as zip_ref:
+            zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
+        plugin_lib_path = (
+            "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
+        )
+    return plugin_lib_path
+
+
+def load_tensorrt_llm() -> bool:
+    """
+    Attempts to load the TensorRT-LLM plugin and initialize it.
+    The plugin path can be specified via the TRTLLM_PLUGINS_PATH environment variable,
+    or USE_TRTLLM_PLUGINS can be set to one of (1, true, yes, on) to download the TRT-LLM distribution and load the plugin from it.
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if not plugin_lib_path:
+        # This option is consulted when TRTLLM_PLUGINS_PATH is not set by the user
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library. Please set one of the two to use TRT-LLM libraries in Torch-TensorRT"
+            )
+            return False
+        else:
+            # Default Python version used to select the wheel
+            py_version = "cp312"
+            platform = Platform.current_platform()
+
+            platform = str(platform).lower()
+            plugin_lib_path = download_plugin_lib_path(py_version, platform)
+
+    try:
+        # Load the shared TRT-LLM plugin library
+        handle = ctypes.CDLL(plugin_lib_path)
+        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        if "libmpi" in str(e_os_error):
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"The dependency libmpi.so is missing. "
" + f"Please install the packages libmpich-dev and libopenmpi-dev.", + exc_info=e_os_error, + ) + else: + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" + f"Ensure the path is correct and the library is compatible", + exc_info=e_os_error, + ) + return False + + try: + # Configure plugin initialization arguments + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + logger.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + # Initialize the plugin + TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" + if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): + logger.info("TensorRT-LLM plugin successfully initialized") + return True + else: + logger.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + logger.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False