Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 45 additions & 69 deletions tensorrt_llm/_torch/auto_deploy/transform/library/visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Transformation to the graph to render nicely in model_explorer."""

import json
from typing import Tuple

import torch
Expand All @@ -9,69 +8,13 @@

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry

try:
import model_explorer
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
from model_explorer.pytorch_exported_program_adater_impl import (
PytorchExportedProgramAdapterImpl,
)
except ImportError:
model_explorer = None
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
# Optionally, you can log a warning or handle this gracefully elsewhere


def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
shape = tensor.shape
total_size = 1
for dim in shape:
total_size *= dim

if size_limit < 0 or size_limit >= total_size:
return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist())

return json.dumps(
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist()
)


def _get_shape(val):
return json.dumps(
list(
map(
lambda x: int(x) if str(x).isdigit() else str(x),
val.shape,
)
)
)


def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
out_vals = fx_node.meta.get("val")
if out_vals is None:
return

if isinstance(out_vals, (tuple, list)):
for idx, val in enumerate(out_vals):
metadata = MetadataItem(id=str(idx), attrs=[])
if val is None:
continue
dtype = str(val.dtype)
shape = _get_shape(val)
metadata.attrs.append(KeyValue(key="tensor_shape", value=dtype + shape))
node.outputsMetadata.append(metadata)
elif isinstance(out_vals, torch.Tensor):
dtype = str(out_vals.dtype)
shape = _get_shape(out_vals)
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value=dtype + shape)])
node.outputsMetadata.append(metadata)
elif isinstance(out_vals, bool):
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value="bool[1]")])
node.outputsMetadata.append(metadata)
else:
raise ValueError(f"Unsupported output type: {type(out_vals)}")


# TODO(yudong): make custom_ops configurable
Expand All @@ -87,24 +30,57 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):

@TransformRegistry.register("visualize_namespace")
class VisualizeNamespace(BaseTransform):
"""Transform to visualize the graph using Model Explorer.

This transform exports the graph module to an ExportedProgram and launches
Model Explorer for interactive visualization. The visualization helps debug
and understand the graph structure after AutoDeploy transformations.
"""

def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
"""Export the graph and launch Model Explorer for visualization.

Args:
gm: The graph module to visualize.
cm: The cached sequence interface with input arguments.
factory: The model factory (unused).
shared_config: Shared configuration across transforms (unused).

Returns:
A tuple of the unchanged graph module and transform info indicating
whether visualization was successful or skipped.
"""
if model_explorer is None:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)

try:
# Export graph module to ExportedProgram for visualization
exported_program = te.export(gm, args=(), kwargs=cm.named_args, dynamic_shapes=None)

# TODO(yudong): make viz as non-block call.
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
graph = ep.graph
# Ensure the ops land up in the right module for better viz
for n in graph.nodes:
if n.target in CUSTOM_OPS:
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
# Ensure custom ops appear nested under their parent modules
for node in exported_program.graph.nodes:
if node.target in CUSTOM_OPS:
if node.args and hasattr(node.args[0], "meta"):
node.meta["nn_module_stack"] = node.args[0].meta.get("nn_module_stack", {})

model_explorer.visualize_pytorch("model-viz", ep)
ad_logger.info("Launching Model Explorer visualization...")
model_explorer.visualize_pytorch("model-viz", exported_program)

return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
return gm, TransformInfo(
skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True
)

except Exception as e:
ad_logger.error(f"Failed to visualize graph with Model Explorer: {e}")
# Don't fail the pipeline if visualization fails
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)