@@ -258,6 +258,8 @@ class TosaPipelineBI(BasePipelineMaker, Generic[T]):
258
258
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
259
259
if not using use_edge_to_transform_and_lower.
260
260
261
+ run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
262
+
261
263
tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
262
264
options.
263
265
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -270,6 +272,7 @@ def __init__(
270
272
test_data : T ,
271
273
aten_op : str | List [str ],
272
274
exir_op : Optional [str | List [str ]] = None ,
275
+ run_on_tosa_ref_model : bool = True ,
273
276
tosa_version : str = "TOSA-0.80+BI" ,
274
277
symmetric_io_quantization : bool = False ,
275
278
use_to_edge_transform_and_lower : bool = True ,
@@ -324,13 +327,14 @@ def __init__(
324
327
suffix = "quant_nodes" ,
325
328
)
326
329
327
- self .add_stage (
328
- self .tester .run_method_and_compare_outputs ,
329
- atol = atol ,
330
- rtol = rtol ,
331
- qtol = qtol ,
332
- inputs = self .test_data ,
333
- )
330
+ if run_on_tosa_ref_model :
331
+ self .add_stage (
332
+ self .tester .run_method_and_compare_outputs ,
333
+ atol = atol ,
334
+ rtol = rtol ,
335
+ qtol = qtol ,
336
+ inputs = self .test_data ,
337
+ )
334
338
335
339
336
340
class TosaPipelineMI (BasePipelineMaker , Generic [T ]):
@@ -345,6 +349,8 @@ class TosaPipelineMI(BasePipelineMaker, Generic[T]):
345
349
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
346
350
if not using use_edge_to_transform_and_lower.
347
351
352
+ run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
353
+
348
354
tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
349
355
options.
350
356
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -357,6 +363,7 @@ def __init__(
357
363
test_data : T ,
358
364
aten_op : str | List [str ],
359
365
exir_op : Optional [str | List [str ]] = None ,
366
+ run_on_tosa_ref_model : bool = True ,
360
367
tosa_version : str = "TOSA-0.80+MI" ,
361
368
use_to_edge_transform_and_lower : bool = True ,
362
369
custom_path : str = None ,
@@ -385,13 +392,14 @@ def __init__(
385
392
suffix = "quant_nodes" ,
386
393
)
387
394
388
- self .add_stage (
389
- self .tester .run_method_and_compare_outputs ,
390
- atol = atol ,
391
- rtol = rtol ,
392
- qtol = qtol ,
393
- inputs = self .test_data ,
394
- )
395
+ if run_on_tosa_ref_model :
396
+ self .add_stage (
397
+ self .tester .run_method_and_compare_outputs ,
398
+ atol = atol ,
399
+ rtol = rtol ,
400
+ qtol = qtol ,
401
+ inputs = self .test_data ,
402
+ )
395
403
396
404
397
405
class EthosU55PipelineBI (BasePipelineMaker , Generic [T ]):
0 commit comments