11"""Transformation to the graph to render nicely in model_explorer."""
22
3- import json
43from typing import Tuple
54
65import torch
98
109from ...models .factory import ModelFactory
1110from ...shim .interface import CachedSequenceInterface
11+ from ...utils .logger import ad_logger
1212from ..interface import BaseTransform , SharedConfig , TransformInfo , TransformRegistry
1313
1414try :
1515 import model_explorer
16- from model_explorer .graph_builder import GraphNode , KeyValue , MetadataItem
17- from model_explorer .pytorch_exported_program_adater_impl import (
18- PytorchExportedProgramAdapterImpl ,
19- )
2016except ImportError :
2117 model_explorer = None
22- GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
23- # Optionally, you can log a warning or handle this gracefully elsewhere
24-
25-
26- def print_tensor (self , tensor : torch .Tensor , size_limit : int = 16 ):
27- shape = tensor .shape
28- total_size = 1
29- for dim in shape :
30- total_size *= dim
31-
32- if size_limit < 0 or size_limit >= total_size :
33- return json .dumps (tensor .cpu ().detach ().to (torch .float32 ).numpy ().tolist ())
34-
35- return json .dumps (
36- (tensor .cpu ().detach ().to (torch .float32 ).numpy ().flatten ())[:size_limit ].tolist ()
37- )
38-
39-
40- def _get_shape (val ):
41- return json .dumps (
42- list (
43- map (
44- lambda x : int (x ) if str (x ).isdigit () else str (x ),
45- val .shape ,
46- )
47- )
48- )
49-
50-
51- def add_outputs_metadata (self , fx_node : torch .fx .node .Node , node : GraphNode ):
52- out_vals = fx_node .meta .get ("val" )
53- if out_vals is None :
54- return
55-
56- if isinstance (out_vals , (tuple , list )):
57- for idx , val in enumerate (out_vals ):
58- metadata = MetadataItem (id = str (idx ), attrs = [])
59- if val is None :
60- continue
61- dtype = str (val .dtype )
62- shape = _get_shape (val )
63- metadata .attrs .append (KeyValue (key = "tensor_shape" , value = dtype + shape ))
64- node .outputsMetadata .append (metadata )
65- elif isinstance (out_vals , torch .Tensor ):
66- dtype = str (out_vals .dtype )
67- shape = _get_shape (out_vals )
68- metadata = MetadataItem (id = "0" , attrs = [KeyValue (key = "tensor_shape" , value = dtype + shape )])
69- node .outputsMetadata .append (metadata )
70- elif isinstance (out_vals , bool ):
71- metadata = MetadataItem (id = "0" , attrs = [KeyValue (key = "tensor_shape" , value = "bool[1]" )])
72- node .outputsMetadata .append (metadata )
73- else :
74- raise ValueError (f"Unsupported output type: { type (out_vals )} " )
7518
7619
7720# TODO(yudong): make custom_ops configurable
@@ -87,24 +30,57 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
8730
8831@TransformRegistry .register ("visualize_namespace" )
8932class VisualizeNamespace (BaseTransform ):
33+ """Transform to visualize the graph using Model Explorer.
34+
35+ This transform exports the graph module to an ExportedProgram and launches
36+ Model Explorer for interactive visualization. The visualization helps debug
37+ and understand the graph structure after AutoDeploy transformations.
38+ """
39+
9040 def _apply (
9141 self ,
9242 gm : GraphModule ,
9343 cm : CachedSequenceInterface ,
9444 factory : ModelFactory ,
9545 shared_config : SharedConfig ,
9646 ) -> Tuple [GraphModule , TransformInfo ]:
97- PytorchExportedProgramAdapterImpl .print_tensor = print_tensor
98- PytorchExportedProgramAdapterImpl .add_outputs_metadata = add_outputs_metadata
47+ """Export the graph and launch Model Explorer for visualization.
48+
49+ Args:
50+ gm: The graph module to visualize.
51+ cm: The cached sequence interface with input arguments.
52+ factory: The model factory (unused).
53+ shared_config: Shared configuration across transforms (unused).
54+
55+ Returns:
56+ A tuple of the unchanged graph module and transform info indicating
57+ whether visualization was successful or skipped.
58+ """
59+ if model_explorer is None :
60+ return gm , TransformInfo (
61+ skipped = True , num_matches = 0 , is_clean = True , has_valid_shapes = True
62+ )
63+
64+ try :
65+ # Export graph module to ExportedProgram for visualization
66+ exported_program = te .export (gm , args = (), kwargs = cm .named_args , dynamic_shapes = None )
9967
100- # TODO(yudong): make viz as non-block call.
101- ep = te .export (gm , args = cm .args , dynamic_shapes = cm .dynamic_shapes )
102- graph = ep .graph
103- # Ensure the ops land up in the right module for better viz
104- for n in graph .nodes :
105- if n .target in CUSTOM_OPS :
106- n .meta ["nn_module_stack" ] = n .args [0 ].meta ["nn_module_stack" ]
68+ # Ensure custom ops appear nested under their parent modules
69+ for node in exported_program .graph .nodes :
70+ if node .target in CUSTOM_OPS :
71+ if node .args and hasattr (node .args [0 ], "meta" ):
72+ node .meta ["nn_module_stack" ] = node .args [0 ].meta .get ("nn_module_stack" , {})
10773
108- model_explorer .visualize_pytorch ("model-viz" , ep )
74+ ad_logger .info ("Launching Model Explorer visualization..." )
75+ model_explorer .visualize_pytorch ("model-viz" , exported_program )
10976
110- return gm , TransformInfo (skipped = False , num_matches = 1 , is_clean = True , has_valid_shapes = True )
77+ return gm , TransformInfo (
78+ skipped = False , num_matches = 1 , is_clean = True , has_valid_shapes = True
79+ )
80+
81+ except Exception as e :
82+ ad_logger .error (f"Failed to visualize graph with Model Explorer: { e } " )
83+ # Don't fail the pipeline if visualization fails
84+ return gm , TransformInfo (
85+ skipped = True , num_matches = 0 , is_clean = True , has_valid_shapes = True
86+ )
0 commit comments