
Commit fddc075

Added engine back-to-back break for CPU memory optimization
1 parent 880b639 commit fddc075

2 files changed: +76 -1 lines changed

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 68 additions & 0 deletions
@@ -1,6 +1,7 @@
 import logging
 from typing import Collection, Dict, List, Optional, Tuple
 
+import psutil
 import torch
 import torch.fx.passes.operator_support as ops
 from torch.fx.node import Target
@@ -225,13 +226,80 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
+        num_of_break = self.calculate_num_of_break(subgraphs)
+        subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break)
+
         # Set the number of TRT engines to be generated
         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
 
         # Tag the accelerated nodes and split the graph accordingly
         self.tag(subgraphs)
         return self.split()
 
+    def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
+        """
+        This function calculates the break period based on the number of subgraphs.
+        """
+        rss = psutil.Process().memory_info().rss
+        available_rss = psutil.virtual_memory().available
+        num_of_graphs = len(subgraphs)
+        if rss < available_rss * 0.3:
+            num_of_graphs = 1
+        elif rss < available_rss * 0.5:
+            num_of_graphs = 2
+        elif rss < available_rss:
+            num_of_graphs = 4
+        elif rss < available_rss * 1.5:
+            num_of_graphs = 8
+        elif rss < available_rss * 2:
+            num_of_graphs = 16
+        else:
+            num_of_graphs = 32
+
+        return max(
+            1, num_of_graphs // ((len(subgraphs) + 1) // 2)
+        )  # If there are already graph breaks, for each TRT subgraph, we break for a few times.
+
+    def break_subgraphs(
+        self, subgraphs: List[Subgraph], num_of_break: int = 1
+    ) -> List[Subgraph]:
+        """
+        This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
+        """
+
+        num_of_sdpa_node = len(
+            [node for node in self.acc_nodes if "scaled_dot" in str(node.target)]
+        )
+        break_period = num_of_sdpa_node // num_of_break + 1
+        current_break_idx = 0
+        current_num_break = 0
+        new_subgraphs = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                for i, node in enumerate(subgraph.nodes):
+                    if "scaled_dot" in str(node.target):
+                        current_num_break += 1
+                        if current_num_break % break_period != 0:
+                            continue
+                        new_subgraphs.append(
+                            Subgraph(
+                                is_acc=True,
+                                nodes=subgraph.nodes[current_break_idx : i + 1],
+                                device_ordinal=subgraph.device_ordinal,
+                            )
+                        )
+                        current_break_idx = i + 1
+                new_subgraphs.append(
+                    Subgraph(
+                        is_acc=True,
+                        nodes=subgraph.nodes[current_break_idx:],
+                        device_ordinal=subgraph.device_ordinal,
+                    )
+                )
+            else:
+                new_subgraphs.append(subgraph)
+        return new_subgraphs
+
     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
         """Generates starter nodes for partitioning + segmentation"""
         # Starter accelerated nodes are all callable accelerated ops
tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 8 additions & 1 deletion
@@ -25,13 +25,17 @@
 TORCH_TRT_DECOMPOSITIONS.pop(
     torch.ops.aten._scaled_dot_product_efficient_attention.default, None
 )
+TORCH_TRT_DECOMPOSITIONS.pop(
+    torch.ops.aten._scaled_dot_product_cudnn_attention.default, None
+)
 TORCH_TRT_DECOMPOSITIONS.pop(
     torch.ops.aten._scaled_dot_product_flash_attention.default, None
 )
 
 REPLACEABLE_ATEN_OPS = {
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
 }
 
 from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
@@ -68,7 +72,10 @@ def _process_sdpa_node(
         ValueError: If the SDPA node has an unexpected number of arguments
     """
 
-    if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+    if node.target in [
+        torch.ops.aten._scaled_dot_product_efficient_attention.default,
+        torch.ops.aten._scaled_dot_product_cudnn_attention.default,
+    ]:
        if len(node.args) == 7:
            (
                query,
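With the cudnn variant registered, the lowering pass treats all three ATen SDPA overloads the same way when it looks for attention nodes in the traced graph. The sketch below shows that matching on an FX GraphModule, assuming a recent PyTorch build that exposes the cudnn SDPA overload; the count_sdpa_nodes helper is hypothetical and simply mirrors the op set above.

import torch

# Mirrors REPLACEABLE_ATEN_OPS after this change.
SDPA_OPS = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
}


def count_sdpa_nodes(gm: torch.fx.GraphModule) -> int:
    # Count call_function nodes whose target is one of the SDPA variants.
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target in SDPA_OPS
    )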
