diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index 492035a76f..4f83678265 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,2 +1,3 @@ __cuda_version__: "12.8" __tensorrt_version__: "10.11.0" +__tensorrt_llm_version__: "0.17.0.post1" diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index ff7d3b7a07..1d44b49874 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -103,6 +103,7 @@ def cross_compile_for_windows( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -177,6 +178,7 @@ def cross_compile_for_windows( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -339,6 +341,7 @@ def cross_compile_for_windows( "enable_weight_streaming": enable_weight_streaming, "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, + "use_distributed_mode_trace": use_distributed_mode_trace, } # disable the following settings is not supported for cross compilation for windows feature @@ -439,6 +442,7 @@ def compile( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -516,6 +520,7 @@ def compile( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. 
This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -688,6 +693,7 @@ def compile( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) @@ -1052,6 +1058,7 @@ def convert_exported_program_to_serialized_trt_engine( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1116,6 +1123,7 @@ def convert_exported_program_to_serialized_trt_engine( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ @@ -1238,6 +1246,7 @@ def convert_exported_program_to_serialized_trt_engine( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 896bf37b42..1ca1b33caf 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,8 +1,6 @@ import collections -import ctypes import functools import logging -import os from typing import ( Any, Callable, @@ -1117,69 +1115,6 @@ def args_bounds_check( return args[i] if len(args) > i and args[i] is not None else replacement -def load_tensorrt_llm() -> bool: - """ - Attempts to load the TensorRT-LLM plugin and initialize it. - - Returns: - bool: True if the plugin was successfully loaded and initialized, False otherwise. - """ - try: - import tensorrt_llm as trt_llm # noqa: F401 - - _LOGGER.info("TensorRT-LLM successfully imported") - return True - except (ImportError, AssertionError) as e_import_error: - # Check for environment variable for the plugin library path - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") - if not plugin_lib_path: - _LOGGER.warning( - "TensorRT-LLM is not installed. 
Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops", - ) - return False - - _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}") - try: - # Load the shared library - handle = ctypes.CDLL(plugin_lib_path) - _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") - except OSError as e_os_error: - _LOGGER.error( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" - f"Ensure the path is correct and the library is compatible", - exc_info=e_os_error, - ) - return False - - try: - # Configure plugin initialization arguments - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - handle.initTrtLlmPlugins.restype = ctypes.c_bool - except AttributeError as e_plugin_unavailable: - _LOGGER.warning( - "Unable to initialize the TensorRT-LLM plugin library", - exc_info=e_plugin_unavailable, - ) - return False - - try: - # Initialize the plugin - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): - _LOGGER.info("TensorRT-LLM plugin successfully initialized") - return True - else: - _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") - return False - except Exception as e_initialization_error: - _LOGGER.warning( - "Exception occurred during TensorRT-LLM plugin library initialization", - exc_info=e_initialization_error, - ) - return False - return False - - def promote_trt_tensors_to_same_dtype( ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str ) -> tuple[TRTTensor, TRTTensor]: diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 79611c7552..aecc99b1f1 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -11,15 +11,15 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, ) -from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_reduce_scatter_op, ) +from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl _LOGGER: logging.Logger = logging.getLogger(__name__) -if load_tensorrt_llm(): +if load_tensorrt_llm_for_nccl(): @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) def fused_nccl_gather( diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 0703fd1cb9..701c920353 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,11 +1,26 @@ from __future__ import annotations +import ctypes import gc +import getpass import logging +import os +import tempfile +import urllib.request import warnings from dataclasses import fields, replace from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import sympy @@ -14,9 +29,10 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device -from torch_tensorrt._enums import dtype +from torch_tensorrt._enums import Platform, dtype 
from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input +from torch_tensorrt._version import __tensorrt_llm_version__ from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._engine_cache import BaseEngineCache @@ -33,6 +49,7 @@ RTOL = 5e-3 ATOL = 5e-3 CPU_DEVICE = "cpu" +_WHL_CPYTHON_VERSION = "cp310" class Frameworks(Enum): @@ -820,3 +837,215 @@ def is_tegra_platform() -> bool: if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: return True return False + + +def is_platform_supported_for_trtllm(platform: str) -> bool: + """ + Checks if the current platform supports TensorRT-LLM plugins for NCCL backend + Returns: + bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise. + Note: + TensorRT-LLM plugins for NCCL backend are not supported on: + - Windows platforms + - Jetson devices (aarch64 architecture) + """ + if "windows" in platform: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Windows" + ) + return False + if "aarch64" in platform: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)" + ) + return False + return True + + +def _cache_root() -> Path: + username = getpass.getuser() + return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" + + +def _extracted_dir_trtllm(platform: str) -> Path: + return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}" + + +def download_and_get_plugin_lib_path(platform: str) -> Optional[str]: + """ + Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. + + Args: + platform (str): Platform identifier (e.g., 'linux_x86_64') + + Returns: + Optional[str]: Path to shared library or None if operation fails. + """ + wheel_filename = ( + f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-" + f"{_WHL_CPYTHON_VERSION}-{platform}.whl" + ) + wheel_path = _cache_root() / wheel_filename + extract_dir = _extracted_dir_trtllm(platform) + # else will never be met though + lib_filename = ( + "libnvinfer_plugin_tensorrt_llm.so" + if "linux" in platform + else "libnvinfer_plugin_tensorrt_llm.dll" + ) + # eg: /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so + plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename + + if plugin_lib_path.exists(): + return str(plugin_lib_path) + + wheel_path.parent.mkdir(parents=True, exist_ok=True) + extract_dir.mkdir(parents=True, exist_ok=True) + + if not wheel_path.exists(): + base_url = "https://pypi.nvidia.com/tensorrt-llm/" + download_url = base_url + wheel_filename + try: + logger.debug(f"Downloading {download_url} ...") + urllib.request.urlretrieve(download_url, wheel_path) + logger.debug("Download succeeded and TRT-LLM wheel is now present") + except urllib.error.HTTPError as e: + logger.error( + f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" + ) + except urllib.error.URLError as e: + logger.error( + f"URL error when trying to download {download_url}: {e.reason}" + ) + except OSError as e: + logger.error(f"Local file write error: {e}") + + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. 
Please install zipfile" + ) + try: + with zipfile.ZipFile(wheel_path) as zip_ref: + zip_ref.extractall(extract_dir) + logger.debug(f"Extracted wheel to {extract_dir}") + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {wheel_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {wheel_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e + + try: + wheel_path.unlink(missing_ok=True) + logger.debug(f"Deleted wheel file: {wheel_path}") + except Exception as e: + logger.warning(f"Could not delete wheel file {wheel_path}: {e}") + if not plugin_lib_path.exists(): + logger.error( + f"Plugin library not found at expected location: {plugin_lib_path}" + ) + return None + + return str(plugin_lib_path) + + +def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: + """ + Loads and initializes the TensorRT-LLM plugin from the given shared library path. + + Args: + plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. + + Returns: + bool: True if successful, False otherwise. + """ + try: + handle = ctypes.CDLL(plugin_lib_path) + logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + if "libmpi" in str(e_os_error): + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", + exc_info=e_os_error, + ) + else: + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " + f"Ensure the path is correct and the library is compatible.", + exc_info=e_os_error, + ) + return False + + try: + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + logger.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): + logger.info("TensorRT-LLM plugin successfully initialized") + return True + else: + logger.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + logger.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False + + +def load_tensorrt_llm_for_nccl() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + Either the env variable TRTLLM_PLUGINS_PATH can specify the path + Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. 
+ """ + # Check platform compatibility first + platform = Platform.current_platform() + platform = str(platform).lower() + if not is_platform_supported_for_trtllm(platform): + return False + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + + if plugin_lib_path: + return load_and_initialize_trtllm_plugin(plugin_lib_path) + else: + # If TRTLLM_PLUGINS_PATH is not set, the user can opt in to downloading the plugin via USE_TRTLLM_PLUGINS + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if not use_trtllm_plugin: + logger.warning( + "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library. Please set one of the two to use the TRT-LLM plugins in Torch-TensorRT" + ) + return False + + plugin_lib_path = download_and_get_plugin_lib_path(platform) + return load_and_initialize_trtllm_plugin(plugin_lib_path) # type: ignore[arg-type] + return False diff --git a/setup.py b/setup.py index fb96d85453..eb4a6120ed 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ __version__: str = "0.0.0" __cuda_version__: str = "0.0" __tensorrt_version__: str = "0.0" +__tensorrt_llm_version__: str = "0.0" LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$") @@ -63,6 +64,7 @@ def get_base_version() -> str: def load_dep_info(): global __cuda_version__ global __tensorrt_version__ + global __tensorrt_llm_version__ with open("dev_dep_versions.yml", "r") as stream: versions = yaml.safe_load(stream) if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None: @@ -72,6 +74,7 @@ def load_dep_info(): else: __cuda_version__ = versions["__cuda_version__"] __tensorrt_version__ = versions["__tensorrt_version__"] + __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"] load_dep_info() @@ -240,6 +243,7 @@ def gen_version_file(): f.write('__version__ = "' + __version__ + '"\n') f.write('__cuda_version__ = "' + __cuda_version__ + '"\n') f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') + f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n') def copy_libtorchtrt(multilinux=False, rt_only=False): diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index e3062249fa..bc058aaaec 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -13,7 +13,6 @@ def set_environment_variables_pytest(): os.environ["RANK"] = str(0) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(29500) - os.environ["USE_TRTLLM_PLUGINS"] = "1" def initialize_logger(rank, logger_file_name): diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 89c94300b7..e8bca66efe 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -1,42 +1,76 @@ import os +import unittest import torch import torch.distributed as dist import torch.nn as nn +from conversion.harness import DispatchTestCase from distributed_utils import set_environment_variables_pytest from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt._enums import Platform -set_environment_variables_pytest() -dist.init_process_group(backend="nccl", init_method="env://") -group = dist.new_group(ranks=[0]) -group_name = group.group_name -world_size = 1 -from conversion.harness import DispatchTestCase + +class DistributedGatherModel(nn.Module): + def __init__(self, input_dim, world_size, 
group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( + x, self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(gathered_tensor) + + +class DistributedReduceScatterModel(nn.Module): + def __init__(self, input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + out = torch.ops._c10d_functional.reduce_scatter_tensor( + x, "sum", self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(out) + + +platform_str = str(Platform.current_platform()).lower() class TestGatherNcclOpsConverter(DispatchTestCase): - @parameterized.expand([8]) - def test_nccl_ops(self, linear_layer_dim): - class DistributedGatherModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( - x, world_size, group_name - ) - gathered_tensor = torch.ops._c10d_functional.wait_tensor( - gathered_tensor - ) - return gathered_tensor + @unittest.skipIf( + "win" in platform_str or "aarch64" in platform_str, + "Skipped on Windows and Jetson: NCCL backend is not supported.", + ) + @classmethod + def setUpClass(cls): + set_environment_variables_pytest() + print("USE_TRTLLM_PLUGINS =", os.environ.get("USE_TRTLLM_PLUGINS")) + cls.world_size = 1 + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + cls.group = dist.new_group(ranks=[0]) + cls.group_name = cls.group.group_name + + @classmethod + def tearDownClass(cls): + if dist.is_initialized(): + dist.destroy_process_group() + @parameterized.expand([8]) + def test_nccl_ops_gather(self, linear_layer_dim): inputs = [torch.randn(1, linear_layer_dim).to("cuda")] self.run_test( - DistributedGatherModel(linear_layer_dim).cuda(), + DistributedGatherModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, @@ -44,28 +78,11 @@ def forward(self, x): @parameterized.expand([8]) def test_nccl_ops_scatter(self, linear_layer_dim): - - class DistributedReduceScatterModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - scatter_reduce_tensor = ( - torch.ops._c10d_functional.reduce_scatter_tensor( - x, "sum", world_size, group_name - ) - ) - scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor( - scatter_reduce_tensor - ) - return scatter_reduce_tensor - inputs = [torch.zeros(1, linear_layer_dim).to("cuda")] - self.run_test( - DistributedReduceScatterModel(linear_layer_dim).cuda(), + DistributedReduceScatterModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh index dd54700048..677d0cb9bc 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.sh +++ b/tests/py/dynamo/distributed/test_nccl_ops.sh @@ -70,51 +70,6 @@ ensure_pytest_installed(){ echo "Setting up the environment" -OS="$(uname -s)" -ARCH="$(uname -m)" - - -#getting the file name for TensorRT-LLM download -if [[ 
"$OS" == "Linux" && "$ARCH" == "x86_64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl" -elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl" -else: - echo "Unsupported platform: OS=$OS ARCH=$ARCH - exit 1 -fi - -# Download the selected file -URL="https://pypi.nvidia.com/tensorrt-llm/$FILE" -echo "Downloading $FILE from $URL..." - -#Installing wget -ensure_installed wget - -#Downloading the file -filename=$(basename "$URL") -if [ -f "$filename" ]; then - echo "File already exists: $filename" -else - wget "$URL" -fi -echo "Download complete: $FILE" - -UNZIP_DIR="tensorrt_llm_unzip" -if [[ ! -d "$UNZIP_DIR" ]]; then - echo "Creating directory: $UNZIP_DIR" - mkdir -p "$UNZIP_DIR" - echo "extracting $FILE to $UNZIP_DIR ..." - #Installing unzip - ensure_installed unzip - #unzip the TensorRT-LLM package - unzip -q "$FILE" -d "$UNZIP_DIR" - echo "Unzip complete" -fi - - -export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" -echo ${TRTLLM_PLUGINS_PATH} ensure_mpi_installed libmpich-dev ensure_mpi_installed libopenmpi-dev @@ -123,7 +78,7 @@ run_tests() { cd .. export PYTHONPATH=$(pwd) echo "Running pytest on distributed/test_nccl_ops.py..." - pytest distributed/test_nccl_ops.py + USE_TRTLLM_PLUGINS=1 pytest distributed/test_nccl_ops.py } run_mpi_tests(){