@@ -193,39 +193,40 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
193193
194194
195195@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
196- class ReplaceSqueezeAndUnsqueezeWithViewPass (ExportPass ):
196+ class ReplaceSqueezeAndUnsqueezeWithViewPass (RemoveOrReplacePassInterface ):
197197 """
198198 When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
199199 view_copy op
200200 """
201201
202- def call_operator (
203- self ,
204- op ,
205- args : Tuple [Argument , ...],
206- kwargs : Dict [str , Argument ],
207- meta : NodeMetadata ,
208- ) -> ProxyValue :
209- # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
210- # which allows us to cover all overloads.
211- if get_edge_overload_packet (op ) not in {
212- exir_ops .edge .aten .squeeze_copy ,
213- exir_ops .edge .aten .unsqueeze_copy ,
214- }:
215- return super ().call_operator (op , args , kwargs , meta )
202+ @property
203+ def targets (self ) -> list [EdgeOpOverload ]:
204+ return [
205+ exir_ops .edge .aten .squeeze_copy .default ,
206+ exir_ops .edge .aten .squeeze_copy .dim ,
207+ exir_ops .edge .aten .squeeze_copy .dims ,
208+ exir_ops .edge .aten .unsqueeze_copy .default ,
209+ ]
210+
211+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
216212 # Get the output tensor shape
217- out_shape = meta ["val" ].shape
213+ out_shape = node . meta ["val" ].shape
218214
219215 # Bail out if any dim is not an int (dynamic shape)
220216 for dim in list (out_shape ):
221217 if not isinstance (dim , int ):
222- return super (). call_operator ( op , args , kwargs , meta )
218+ return False
223219
224- # Return a view op with the new shape
225- view_args = (args [0 ], list (out_shape ))
226- return super ().call_operator (
227- exir_ops .edge .aten .view_copy .default , view_args , kwargs , meta
228- )
220+ # Replace with view op with the new shape
221+ with node .graph .inserting_before (node ):
222+ new_node = node .graph .call_function (
223+ exir_ops .edge .aten .view_copy .default ,
224+ args = (node .args [0 ], list (out_shape )),
225+ )
226+ # Do not remove the metadata copy!
227+ new_node .meta = node .meta
228+ node .replace_all_uses_with (new_node )
229+ return True
229230
230231
231232@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
0 commit comments