Skip to content

Commit c19c1c1

Browse files
fix: update visualization and update transform metadata handling
Signed-off-by: Karthik Vetrivel <[email protected]>
1 parent 478b6b2 commit c19c1c1

File tree

1 file changed

+45
-69
lines changed

1 file changed

+45
-69
lines changed
Lines changed: 45 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Transformation to the graph to render nicely in model_explorer."""
22

3-
import json
43
from typing import Tuple
54

65
import torch
@@ -9,69 +8,13 @@
98

109
from ...models.factory import ModelFactory
1110
from ...shim.interface import CachedSequenceInterface
11+
from ...utils.logger import ad_logger
1212
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
1313

1414
try:
1515
import model_explorer
16-
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
17-
from model_explorer.pytorch_exported_program_adater_impl import (
18-
PytorchExportedProgramAdapterImpl,
19-
)
2016
except ImportError:
2117
model_explorer = None
22-
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
23-
# Optionally, you can log a warning or handle this gracefully elsewhere
24-
25-
26-
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
27-
shape = tensor.shape
28-
total_size = 1
29-
for dim in shape:
30-
total_size *= dim
31-
32-
if size_limit < 0 or size_limit >= total_size:
33-
return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist())
34-
35-
return json.dumps(
36-
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist()
37-
)
38-
39-
40-
def _get_shape(val):
41-
return json.dumps(
42-
list(
43-
map(
44-
lambda x: int(x) if str(x).isdigit() else str(x),
45-
val.shape,
46-
)
47-
)
48-
)
49-
50-
51-
def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
52-
out_vals = fx_node.meta.get("val")
53-
if out_vals is None:
54-
return
55-
56-
if isinstance(out_vals, (tuple, list)):
57-
for idx, val in enumerate(out_vals):
58-
metadata = MetadataItem(id=str(idx), attrs=[])
59-
if val is None:
60-
continue
61-
dtype = str(val.dtype)
62-
shape = _get_shape(val)
63-
metadata.attrs.append(KeyValue(key="tensor_shape", value=dtype + shape))
64-
node.outputsMetadata.append(metadata)
65-
elif isinstance(out_vals, torch.Tensor):
66-
dtype = str(out_vals.dtype)
67-
shape = _get_shape(out_vals)
68-
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value=dtype + shape)])
69-
node.outputsMetadata.append(metadata)
70-
elif isinstance(out_vals, bool):
71-
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value="bool[1]")])
72-
node.outputsMetadata.append(metadata)
73-
else:
74-
raise ValueError(f"Unsupported output type: {type(out_vals)}")
7518

7619

7720
# TODO(yudong): make custom_ops configurable
@@ -87,24 +30,57 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
8730

8831
@TransformRegistry.register("visualize_namespace")
8932
class VisualizeNamespace(BaseTransform):
33+
"""Transform to visualize the graph using Model Explorer.
34+
35+
This transform exports the graph module to an ExportedProgram and launches
36+
Model Explorer for interactive visualization. The visualization helps debug
37+
and understand the graph structure after AutoDeploy transformations.
38+
"""
39+
9040
def _apply(
9141
self,
9242
gm: GraphModule,
9343
cm: CachedSequenceInterface,
9444
factory: ModelFactory,
9545
shared_config: SharedConfig,
9646
) -> Tuple[GraphModule, TransformInfo]:
97-
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
98-
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
47+
"""Export the graph and launch Model Explorer for visualization.
48+
49+
Args:
50+
gm: The graph module to visualize.
51+
cm: The cached sequence interface with input arguments.
52+
factory: The model factory (unused).
53+
shared_config: Shared configuration across transforms (unused).
54+
55+
Returns:
56+
A tuple of the unchanged graph module and transform info indicating
57+
whether visualization was successful or skipped.
58+
"""
59+
if model_explorer is None:
60+
return gm, TransformInfo(
61+
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
62+
)
63+
64+
try:
65+
# Export graph module to ExportedProgram for visualization
66+
exported_program = te.export(gm, args=(), kwargs=cm.named_args, dynamic_shapes=None)
9967

100-
# TODO(yudong): make viz as non-block call.
101-
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
102-
graph = ep.graph
103-
# Ensure the ops land up in the right module for better viz
104-
for n in graph.nodes:
105-
if n.target in CUSTOM_OPS:
106-
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
68+
# Ensure custom ops appear nested under their parent modules
69+
for node in exported_program.graph.nodes:
70+
if node.target in CUSTOM_OPS:
71+
if node.args and hasattr(node.args[0], "meta"):
72+
node.meta["nn_module_stack"] = node.args[0].meta.get("nn_module_stack", {})
10773

108-
model_explorer.visualize_pytorch("model-viz", ep)
74+
ad_logger.info("Launching Model Explorer visualization...")
75+
model_explorer.visualize_pytorch("model-viz", exported_program)
10976

110-
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
77+
return gm, TransformInfo(
78+
skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True
79+
)
80+
81+
except Exception as e:
82+
ad_logger.error(f"Failed to visualize graph with Model Explorer: {e}")
83+
# Don't fail the pipeline if visualization fails
84+
return gm, TransformInfo(
85+
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
86+
)

0 commit comments

Comments
 (0)