@@ -306,31 +306,6 @@ def get_args_and_kwargs_conv(
306306
307307 (out_multiplier , out_shift ) = quantize_tensor_multiplier (requantize_scale_t )
308308
309- out_multiplier_ = graph_module .graph .call_function (
310- torch .ops .aten .full .default ,
311- ([1 ], out_multiplier [0 ].item ()),
312- {"dtype" : torch .int32 },
313- )
314- out_shift_ = graph_module .graph .call_function (
315- torch .ops .aten .full .default ,
316- ([1 ], out_shift [0 ].item ()),
317- {"dtype" : torch .int32 },
318- )
319-
320- # Create a single element tensor for the weight zero point
321- weight_zero_point_tensor = graph_module .graph .call_function (
322- torch .ops .aten .full .default ,
323- ([1 ], weight_zero_point ),
324- {"dtype" : torch .int32 },
325- )
326-
327- # Create a single element tensor for the bias scale
328- bias_scale_tensor = graph_module .graph .call_function (
329- torch .ops .aten .full .default ,
330- ([1 ], bias_scale ),
331- {"dtype" : torch .float32 },
332- )
333-
334309 # Make the args and kwargs for the replacement op
335310 args = tuple (inputs_inputs + weights_inputs + [bias ])
336311 kwargs = {
@@ -339,12 +314,12 @@ def get_args_and_kwargs_conv(
339314 "dilation" : dilation ,
340315 "groups" : groups ,
341316 "input_zero_point" : dequants_inputs [0 ].args [2 ],
342- "weight_zero_point" : weight_zero_point_tensor ,
343- "bias_scale" : bias_scale_tensor ,
317+ "weight_zero_point" : weight_zero_point ,
318+ "bias_scale" : bias_scale ,
344319 "out_scale" : quant_node .args [1 ],
345320 "out_zero_point" : quant_node .args [2 ],
346- "out_multiplier" : out_multiplier_ ,
347- "out_shift" : out_shift_ ,
321+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
322+ "out_shift" : out_shift [ 0 ]. item () ,
348323 }
349324 return args , kwargs
350325
@@ -365,27 +340,11 @@ def get_args_and_kwargs_relu(
365340 # Make the args and kwargs for the replacement op
366341 args = tuple (inputs_inputs )
367342
368- X_zero_point = graph_module .graph .call_function (
369- torch .ops .aten .full .default ,
370- ([1 ], dequants_inputs [0 ].args [2 ]),
371- {"dtype" : torch .int32 },
372- )
373- out_multiplier_ = graph_module .graph .call_function (
374- torch .ops .aten .full .default ,
375- ([1 ], out_multiplier [0 ].item ()),
376- {"dtype" : torch .int32 },
377- )
378- out_shift_ = graph_module .graph .call_function (
379- torch .ops .aten .full .default ,
380- ([1 ], out_shift [0 ].item ()),
381- {"dtype" : torch .int32 },
382- )
383-
384343 kwargs = {
385- "X_zero_point" : X_zero_point ,
344+ "X_zero_point" : dequants_inputs [ 0 ]. args [ 2 ] ,
386345 "out_zero_point" : quant_node .args [2 ],
387- "out_multiplier" : out_multiplier_ ,
388- "out_shift" : out_shift_ ,
346+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
347+ "out_shift" : out_shift [ 0 ]. item () ,
389348 }
390349 return args , kwargs
391350
0 commit comments