@@ -102,6 +102,35 @@ class Frameworks(Enum):
 }
 
 
+def unified_dtype_converter(
+    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
+) -> Union[np.dtype, torch.dtype, TRTDataType]:
+    """
+    Convert TensorRT, Numpy, or Torch data types to any other of those data types.
+    Args:
+        dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
+        to (Frameworks): The framework to convert the data type to.
+    Returns:
+        The equivalent data type in the requested framework.
+    """
+    assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
+    trt_major_version = int(trt.__version__.split(".")[0])
+    if dtype in (np.int8, torch.int8, trt.int8):
+        return DataTypeEquivalence[trt.int8][to]
+    elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
+        return DataTypeEquivalence[trt.bool][to]
+    elif dtype in (np.int32, torch.int32, trt.int32):
+        return DataTypeEquivalence[trt.int32][to]
+    elif dtype in (np.int64, torch.int64, trt.int64):
+        return DataTypeEquivalence[trt.int64][to]
+    elif dtype in (np.float16, torch.float16, trt.float16):
+        return DataTypeEquivalence[trt.float16][to]
+    elif dtype in (np.float32, torch.float32, trt.float32):
+        return DataTypeEquivalence[trt.float32][to]
+    else:
+        raise TypeError("%s is not a supported dtype" % dtype)
+
+
 def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
     """
     This is a helper function to delete the instance of module. We first move it to CPU and then
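A minimal usage sketch of the new converter (illustrative only; it assumes the Frameworks enum defines NUMPY, TORCH, and TRT members, which this hunk does not show):

    import numpy as np
    import tensorrt as trt
    import torch

    # Torch dtype -> TensorRT dtype
    assert unified_dtype_converter(torch.float16, Frameworks.TRT) == trt.float16
    # TensorRT dtype -> NumPy dtype
    assert unified_dtype_converter(trt.int32, Frameworks.NUMPY) == np.int32
    # Anything outside the equivalence table (e.g. complex dtypes) raises TypeError.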
@@ -875,8 +904,12 @@ def _cache_root() -> Path:
     return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
 
 
-def _extracted_dir_trtllm(platform: str) -> Path:
-    return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}"
+def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
+    return (
+        _cache_root()
+        / "trtllm"
+        / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
+    )
 
 
 def download_and_get_plugin_lib_path() -> Optional[str]:
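For reference, a sketch of the cache layout the two-argument helper now produces (the version and platform values below are placeholders, not the pinned __tensorrt_llm_version__):

    # With __tensorrt_llm_version__ == "X.Y.Z" on a linux/x86_64 host,
    # the extraction directory resolves to:
    #   <tmpdir>/torch_tensorrt_<username>/trtllm/X.Y.Z_linux_x86_64
    extract_dir = _extracted_dir_trtllm("linux", "x86_64")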
@@ -889,13 +922,14 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
     Returns:
         Optional[str]: Path to shared library or None if operation fails.
     """
+    platform_system = platform.system().lower()
+    platform_machine = platform.machine().lower()
     wheel_filename = (
         f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
-        f"{_WHL_CPYTHON_VERSION}-{platform}.whl"
+        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
     )
-    platform_system = platform.system().lower()
     wheel_path = _cache_root() / wheel_filename
-    extract_dir = _extracted_dir_trtllm(platform_system)
+    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
     # else will never be met though
     lib_filename = (
         "libnvinfer_plugin_tensorrt_llm.so"