py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py (68 additions, 0 deletions)
@@ -1,6 +1,7 @@
import logging
from typing import Collection, Dict, List, Optional, Tuple

import psutil
import torch
import torch.fx.passes.operator_support as ops
from torch.fx.node import Target
@@ -225,13 +226,80 @@ def partition_graph(self) -> torch.fx.GraphModule:
        # Remove segments smaller than the block size (with exceptions)
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)

        num_of_break = self.calculate_num_of_break(subgraphs)
        subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break)

        # Set the number of TRT engines to be generated
        self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])

        # Tag the accelerated nodes and split the graph accordingly
        self.tag(subgraphs)
        return self.split()

    def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
Collaborator: I feel like this is too much of a heuristic-based system. A better approach IMO is to calculate a graph-size budget based on available memory (or eventually this could be user-specified). Then, for each TRT block, we estimate its size and decide how many subgraphs it should be split into to meet the budget.

"""
This function calculates the break period based on the number of subgraphs.
"""
rss = psutil.Process().memory_info().rss
available_rss = psutil.virtual_memory().available
num_of_graphs = len(subgraphs)
if rss < available_rss * 0.3:
num_of_graphs = 1
elif rss < available_rss * 0.5:
num_of_graphs = 2
elif rss < available_rss:
num_of_graphs = 4
elif rss < available_rss * 1.5:
num_of_graphs = 8
elif rss < available_rss * 2:
num_of_graphs = 16
else:
num_of_graphs = 32

return max(
1, num_of_graphs // ((len(subgraphs) + 1) // 2)
) # If there are already graph breaks, for each TRT subgraph, we break for a few times.
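
As a rough sketch of the budget-based alternative suggested in the collaborator comment above (illustration only, not part of this PR; the helper name, the block-size estimate input, and the 25% budget fraction are assumptions):

import psutil


def num_splits_for_budget(estimated_block_bytes: int, budget_fraction: float = 0.25) -> int:
    # Hypothetical sketch of the reviewer's suggestion: derive the number of splits
    # for a TRT block from a memory budget instead of fixed RSS thresholds.
    budget = int(psutil.virtual_memory().available * budget_fraction)
    if budget <= 0:
        return 1
    # Ceiling division: the smallest number of pieces whose individual size fits the budget.
    return max(1, -(-estimated_block_bytes // budget))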

    def break_subgraphs(
        self, subgraphs: List[Subgraph], num_of_break: int = 1
    ) -> List[Subgraph]:
        """
        This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
        """

        num_of_sdpa_node = len(
            [node for node in self.acc_nodes if "scaled_dot" in str(node.target)]
        )
        break_period = num_of_sdpa_node // num_of_break + 1
        current_break_idx = 0
        current_num_break = 0
        new_subgraphs = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                for i, node in enumerate(subgraph.nodes):
                    if "scaled_dot" in str(node.target):
Collaborator: It's fine if we do this for testing, but really we should be taking a much more generic approach rather than assuming only SDPA is a viable break point.

                        current_num_break += 1
                        if current_num_break % break_period != 0:
                            continue
                        new_subgraphs.append(
                            Subgraph(
                                is_acc=True,
                                nodes=subgraph.nodes[current_break_idx : i + 1],
                                device_ordinal=subgraph.device_ordinal,
                            )
                        )
                        current_break_idx = i + 1
                new_subgraphs.append(
                    Subgraph(
                        is_acc=True,
                        nodes=subgraph.nodes[current_break_idx:],
                        device_ordinal=subgraph.device_ordinal,
                    )
                )
            else:
                new_subgraphs.append(subgraph)
        return new_subgraphs
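
Following the collaborator's note above that SDPA should not be the only viable break point, a hedged sketch of a more op-agnostic selection (illustration only, not part of this PR; the helper name and the node-count cap are assumptions):

from typing import List

from torch.fx import Node


def generic_break_points(nodes: List[Node], max_nodes_per_piece: int = 512) -> List[int]:
    # Hypothetical alternative to the SDPA-only check: split purely by node count,
    # so any op can end a piece; a size- or memory-based estimate could replace it.
    return [i for i in range(len(nodes)) if (i + 1) % max_nodes_per_piece == 0]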

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """Generates starter nodes for partitioning + segmentation"""
        # Starter accelerated nodes are all callable accelerated ops
tools/llm/torchtrt_ext/register_sdpa.py (8 additions, 1 deletion)
@@ -25,13 +25,17 @@
TORCH_TRT_DECOMPOSITIONS.pop(
    torch.ops.aten._scaled_dot_product_efficient_attention.default, None
)
TORCH_TRT_DECOMPOSITIONS.pop(
    torch.ops.aten._scaled_dot_product_cudnn_attention.default, None
)
TORCH_TRT_DECOMPOSITIONS.pop(
    torch.ops.aten._scaled_dot_product_flash_attention.default, None
)

REPLACEABLE_ATEN_OPS = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
}

from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
@@ -68,7 +72,10 @@ def _process_sdpa_node(
        ValueError: If the SDPA node has an unexpected number of arguments
    """

-   if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+   if node.target in [
+       torch.ops.aten._scaled_dot_product_efficient_attention.default,
+       torch.ops.aten._scaled_dot_product_cudnn_attention.default,
+   ]:
        if len(node.args) == 7:
            (
                query,