Skip to content

Commit 880b639

Browse files
committed
Revert Windows solution — not working.
1 parent 66b40bd commit 880b639

File tree

3 files changed

+21
-43
lines changed

3 files changed

+21
-43
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def _save_weight_mapping(self) -> None:
596596
torch.cuda.empty_cache()
597597

598598
@needs_refit # type: ignore[misc]
599-
def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
599+
def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
600600
serialized_engine = engine.serialize()
601601
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
602602
# if not self.compilation_settings.strip_engine_weights:
@@ -735,7 +735,7 @@ def run(
735735
return interpreter_result # type: ignore[no-any-return]
736736

737737
self._construct_trt_network_def()
738-
_LOGGER.info(
738+
_LOGGER.debug(
739739
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
740740
)
741741

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,15 @@
22

33
import io
44
import logging
5-
from typing import Any, List, Optional, Sequence
5+
from typing import Any, List, NamedTuple, Optional, Sequence
66

77
import torch
88
from torch_tensorrt._enums import dtype
99
from torch_tensorrt._features import ENABLED_FEATURES
1010
from torch_tensorrt._Input import Input
1111
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
1212
from torch_tensorrt.dynamo._settings import CompilationSettings
13-
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
14-
TRTInterpreter,
15-
TRTInterpreterResult,
16-
)
13+
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
1714
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1815
from torch_tensorrt.dynamo.utils import (
1916
get_cpu_memory_usage,
@@ -24,6 +21,14 @@
2421
logger = logging.getLogger(__name__)
2522

2623

24+
class SerializedInterpreterResult(NamedTuple):
25+
serialized_engine: bytes
26+
input_names: Sequence[str]
27+
output_names: Sequence[str]
28+
weight_name_map: Optional[dict[Any, Any]]
29+
requires_output_allocator: bool
30+
31+
2732
def infer_module_output_dtypes(
2833
module: torch.fx.GraphModule,
2934
truncate_double: bool = False,
@@ -34,7 +39,7 @@ def infer_module_output_dtypes(
3439
"""
3540
outputs = [node for node in module.graph.nodes if node.op == "output"]
3641
outputs = outputs[0].args
37-
return get_output_dtypes(outputs, truncate_double)
42+
return get_output_dtypes(outputs, truncate_double) # type: ignore
3843

3944

4045
def interpret_module_to_result(
@@ -44,7 +49,7 @@ def interpret_module_to_result(
4449
arg_inputs: Optional[Sequence[Input]] = None,
4550
kwarg_inputs: Optional[dict[str, Any]] = None,
4651
engine_cache: Optional[BaseEngineCache] = None,
47-
) -> TRTInterpreterResult:
52+
) -> SerializedInterpreterResult:
4853
"""Interpret an FX module to a TRTInterpreterResult
4954
Args:
5055
module: FX GraphModule to interpret
@@ -84,16 +89,18 @@ def interpret_module_to_result(
8489
with io.BytesIO() as engine_bytes:
8590
engine_bytes.write(serialized_engine)
8691
serialized_engine = engine_bytes.getvalue()
87-
88-
interpreter_result = TRTInterpreterResult(
89-
engine=serialized_engine,
92+
logger.debug(
93+
f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
94+
)
95+
serialized_interpreter_result = SerializedInterpreterResult(
96+
serialized_engine=serialized_engine,
9097
input_names=interpreter_result.input_names,
9198
output_names=interpreter_result.output_names,
9299
weight_name_map=interpreter_result.weight_name_map,
93100
requires_output_allocator=interpreter_result.requires_output_allocator,
94101
)
95102

96-
return interpreter_result
103+
return serialized_interpreter_result
97104

98105

99106
def convert_module(
@@ -132,7 +139,7 @@ def convert_module(
132139
)
133140

134141
return rt_cls(
135-
serialized_engine=interpreter_result.engine,
142+
serialized_engine=interpreter_result.serialized_engine,
136143
input_binding_names=list(interpreter_result.input_names),
137144
output_binding_names=list(interpreter_result.output_names),
138145
name=name,

py/torch_tensorrt/dynamo/utils.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -882,32 +882,3 @@ def release_memory() -> None:
882882
logger.warning("Failed to release CPU memory.")
883883
except Exception:
884884
logger.warning("Failed to release CPU memory.")
885-
886-
elif platform.system() == "Windows":
887-
from ctypes import wintypes
888-
889-
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
890-
psapi = ctypes.WinDLL("psapi", use_last_error=True)
891-
892-
GetCurrentProcess = kernel32.GetCurrentProcess
893-
GetCurrentProcess.restype = wintypes.HANDLE
894-
hproc = GetCurrentProcess()
895-
896-
HeapSetInformation = kernel32.HeapSetInformation
897-
HeapSetInformation.argtypes = [
898-
wintypes.HANDLE,
899-
ctypes.c_int,
900-
ctypes.c_void_p,
901-
ctypes.c_size_t,
902-
]
903-
HeapSetInformation.restype = wintypes.BOOL
904-
GetProcessHeap = kernel32.GetProcessHeap
905-
GetProcessHeap.restype = wintypes.HANDLE
906-
ok = False
907-
try:
908-
HeapOptimizeResources = 3
909-
hheap = GetProcessHeap()
910-
if HeapSetInformation(hheap, HeapOptimizeResources, None, 0):
911-
ok = True
912-
except Exception:
913-
logger.warning("Failed to release CPU memory.")

0 commit comments

Comments
 (0)