55import getpass
66import logging
77import os
8+ import platform
89import tempfile
910import urllib .request
1011import warnings
2930from torch ._subclasses .fake_tensor import FakeTensor
3031from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
3132from torch_tensorrt ._Device import Device
32- from torch_tensorrt ._enums import Platform , dtype
33+ from torch_tensorrt ._enums import dtype
3334from torch_tensorrt ._features import ENABLED_FEATURES
3435from torch_tensorrt ._Input import Input
3536from torch_tensorrt ._version import __tensorrt_llm_version__
@@ -101,37 +102,6 @@ class Frameworks(Enum):
101102 }
102103
103104
104- def unified_dtype_converter (
105- dtype : Union [TRTDataType , torch .dtype , np .dtype ], to : Frameworks
106- ) -> Union [np .dtype , torch .dtype , TRTDataType ]:
107- """
108- Convert TensorRT, Numpy, or Torch data types to any other of those data types.
109-
110- Args:
111- dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
112- to (Frameworks): The framework to convert the data type to.
113-
114- Returns:
115- The equivalent data type in the requested framework.
116- """
117- assert to in Frameworks , f"Expected valid Framework for translation, got { to } "
118- trt_major_version = int (trt .__version__ .split ("." )[0 ])
119- if dtype in (np .int8 , torch .int8 , trt .int8 ):
120- return DataTypeEquivalence [trt .int8 ][to ]
121- elif trt_major_version >= 7 and dtype in (np .bool_ , torch .bool , trt .bool ):
122- return DataTypeEquivalence [trt .bool ][to ]
123- elif dtype in (np .int32 , torch .int32 , trt .int32 ):
124- return DataTypeEquivalence [trt .int32 ][to ]
125- elif dtype in (np .int64 , torch .int64 , trt .int64 ):
126- return DataTypeEquivalence [trt .int64 ][to ]
127- elif dtype in (np .float16 , torch .float16 , trt .float16 ):
128- return DataTypeEquivalence [trt .float16 ][to ]
129- elif dtype in (np .float32 , torch .float32 , trt .float32 ):
130- return DataTypeEquivalence [trt .float32 ][to ]
131- else :
132- raise TypeError ("%s is not a supported dtype" % dtype )
133-
134-
135105def deallocate_module (module : torch .fx .GraphModule , delete_module : bool = True ) -> None :
136106 """
137107 This is a helper function to delete the instance of module. We first move it to CPU and then
@@ -870,29 +840,33 @@ def is_tegra_platform() -> bool:
870840 return False
871841
872842
873- def is_platform_supported_for_trtllm (platform : str ) -> bool :
843+ def is_platform_supported_for_trtllm () -> bool :
874844 """
875- Checks if the current platform supports TensorRT-LLM plugins for NCCL backend
845+ Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.
846+
876847 Returns:
877- bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise.
878- Note:
879- TensorRT-LLM plugins for NCCL backend are not supported on:
880- - Windows platforms
881- - Orin, Xavier, or Tegra devices (aarch64 architecture)
848+ bool: True if supported, False otherwise.
882849
850+ Unsupported:
851+ - Windows platforms
852+ - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
883853 """
884- if "windows" in platform :
854+ system = platform .system ().lower ()
855+ machine = platform .machine ().lower ()
856+ release = platform .release ().lower ()
857+
858+ if "windows" in system :
885859 logger .info (
886- "TensorRT-LLM plugins for NCCL backend are not supported on Windows"
860+ "TensorRT-LLM plugins for NCCL backend are not supported on Windows. "
887861 )
888862 return False
889- if torch .cuda .is_available ():
890- device_name = torch .cuda .get_device_name ().lower ()
891- if any (keyword in device_name for keyword in ["orin" , "xavier" , "tegra" ]):
892- return False
863+
864+ if machine == "aarch64" and "tegra" in release :
893865 logger .info (
894- "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices"
866+ "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices. "
895867 )
868+ return False
869+
896870 return True
897871
898872
@@ -905,7 +879,7 @@ def _extracted_dir_trtllm(platform: str) -> Path:
905879 return _cache_root () / "trtllm" / f"{ __tensorrt_llm_version__ } _{ platform } "
906880
907881
908- def download_and_get_plugin_lib_path (platform : str ) -> Optional [str ]:
882+ def download_and_get_plugin_lib_path () -> Optional [str ]:
909883 """
910884 Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
911885
@@ -919,12 +893,13 @@ def download_and_get_plugin_lib_path(platform: str) -> Optional[str]:
919893 f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -"
920894 f"{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
921895 )
896+ platform_system = platform .system ().lower ()
922897 wheel_path = _cache_root () / wheel_filename
923- extract_dir = _extracted_dir_trtllm (platform )
898+ extract_dir = _extracted_dir_trtllm (platform_system )
924899 # else will never be met though
925900 lib_filename = (
926901 "libnvinfer_plugin_tensorrt_llm.so"
927- if "linux" in platform
902+ if "linux" in platform_system
928903 else "libnvinfer_plugin_tensorrt_llm.dll"
929904 )
930905 # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
@@ -1057,10 +1032,7 @@ def load_tensorrt_llm_for_nccl() -> bool:
10571032 Returns:
10581033 bool: True if the plugin was successfully loaded and initialized, False otherwise.
10591034 """
1060- # Check platform compatibility first
1061- platform = Platform .current_platform ()
1062- platform = str (platform ).lower ()
1063- if not is_platform_supported_for_trtllm (platform ):
1035+ if not is_platform_supported_for_trtllm ():
10641036 return False
10651037 plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
10661038
@@ -1080,6 +1052,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
10801052 )
10811053 return False
10821054
1083- plugin_lib_path = download_and_get_plugin_lib_path (platform )
1055+ plugin_lib_path = download_and_get_plugin_lib_path ()
10841056 return load_and_initialize_trtllm_plugin (plugin_lib_path ) # type: ignore[arg-type]
10851057 return False
0 commit comments