Skip to content

Commit 9cc9f82

Browse files
authored
[ExecuTorch] Arm backend: Buckify cos test (#10505)
Still no reference model Differential Revision: [D73642290](https://our.internmc.facebook.com/intern/diff/D73642290/) ghstack-source-id: 280313776 Pull Request resolved: #10480
1 parent a8422d2 commit 9cc9f82

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

backends/arm/test/ops/test_cos.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
from typing import Tuple
88

9-
import torch
9+
import pytest
1010

11+
import torch
1112
from executorch.backends.arm.test import common, conftest
1213
from executorch.backends.arm.test.tester.test_pipeline import (
1314
EthosU55PipelineBI,
@@ -37,24 +38,28 @@ def forward(self, x: torch.Tensor):
3738

3839

3940
@common.parametrize("test_data", test_data_suite)
41+
@pytest.mark.tosa_ref_model
4042
def test_cos_tosa_MI(test_data: Tuple):
4143
pipeline = TosaPipelineMI[input_t1](
4244
Cos(),
4345
(test_data,),
4446
aten_op,
4547
exir_op=[],
48+
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
4649
)
4750
if conftest.get_option("tosa_version") == "1.0":
4851
pipeline.run()
4952

5053

5154
@common.parametrize("test_data", test_data_suite)
55+
@pytest.mark.tosa_ref_model
5256
def test_cos_tosa_BI(test_data: Tuple):
5357
pipeline = TosaPipelineBI[input_t1](
5458
Cos(),
5559
(test_data,),
5660
aten_op,
5761
exir_op=[],
62+
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
5863
)
5964
pipeline.run()
6065

backends/arm/test/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def define_arm_tests():
1717
"ops/test_slice.py",
1818
"ops/test_sigmoid.py",
1919
"ops/test_tanh.py",
20+
"ops/test_cos.py",
2021
]
2122

2223
# Quantization

backends/arm/test/tester/test_pipeline.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ class TosaPipelineBI(BasePipelineMaker, Generic[T]):
258258
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
259259
if not using use_edge_to_transform_and_lower.
260260
261+
run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
262+
261263
tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
262264
options.
263265
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -270,6 +272,7 @@ def __init__(
270272
test_data: T,
271273
aten_op: str | List[str],
272274
exir_op: Optional[str | List[str]] = None,
275+
run_on_tosa_ref_model: bool = True,
273276
tosa_version: str = "TOSA-0.80+BI",
274277
symmetric_io_quantization: bool = False,
275278
use_to_edge_transform_and_lower: bool = True,
@@ -324,13 +327,14 @@ def __init__(
324327
suffix="quant_nodes",
325328
)
326329

327-
self.add_stage(
328-
self.tester.run_method_and_compare_outputs,
329-
atol=atol,
330-
rtol=rtol,
331-
qtol=qtol,
332-
inputs=self.test_data,
333-
)
330+
if run_on_tosa_ref_model:
331+
self.add_stage(
332+
self.tester.run_method_and_compare_outputs,
333+
atol=atol,
334+
rtol=rtol,
335+
qtol=qtol,
336+
inputs=self.test_data,
337+
)
334338

335339

336340
class TosaPipelineMI(BasePipelineMaker, Generic[T]):
@@ -345,6 +349,8 @@ class TosaPipelineMI(BasePipelineMaker, Generic[T]):
345349
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
346350
if not using use_edge_to_transform_and_lower.
347351
352+
run_on_tosa_ref_model: Set to true to test the tosa file on the TOSA reference model.
353+
348354
tosa_version: A string for identifying the TOSA version, see common.get_tosa_compile_spec for
349355
options.
350356
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
@@ -357,6 +363,7 @@ def __init__(
357363
test_data: T,
358364
aten_op: str | List[str],
359365
exir_op: Optional[str | List[str]] = None,
366+
run_on_tosa_ref_model: bool = True,
360367
tosa_version: str = "TOSA-0.80+MI",
361368
use_to_edge_transform_and_lower: bool = True,
362369
custom_path: str = None,
@@ -385,13 +392,14 @@ def __init__(
385392
suffix="quant_nodes",
386393
)
387394

388-
self.add_stage(
389-
self.tester.run_method_and_compare_outputs,
390-
atol=atol,
391-
rtol=rtol,
392-
qtol=qtol,
393-
inputs=self.test_data,
394-
)
395+
if run_on_tosa_ref_model:
396+
self.add_stage(
397+
self.tester.run_method_and_compare_outputs,
398+
atol=atol,
399+
rtol=rtol,
400+
qtol=qtol,
401+
inputs=self.test_data,
402+
)
395403

396404

397405
class EthosU55PipelineBI(BasePipelineMaker, Generic[T]):

0 commit comments

Comments
 (0)