Skip to content

Commit 440cad9

Browse files
Ninja91facebook-github-bot
authored and committed
Add 16A8W linear ops support and test (#13448)
Summary: - Adds a linear ops test using the 16A8W config in the INT16 profile. - Adds INT16 dtype support to the view op validation. - Validated with the TOSA pipeline test. - Confirmed that tests previously marked flaky are no longer flaky and removed their markers. Note: Not verified with a TOSA reference model run. Differential Revision: D80308822
1 parent b8f5123 commit 440cad9

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

backends/arm/operators/op_view.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[inputs[0], output],
47-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
47+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
4848
output.tosa_spec,
4949
)
5050

backends/arm/test/ops/test_linear.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import pytest
1212

1313
import torch
14-
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
get_symmetric_a16w8_quantization_config,
16+
TOSAQuantizer,
17+
)
18+
from executorch.backends.arm.test import common, conftest
1519

1620
from executorch.backends.arm.test.tester.test_pipeline import (
1721
EthosU55PipelineINT,
@@ -20,6 +24,8 @@
2024
TosaPipelineINT,
2125
VgfPipeline,
2226
)
27+
from executorch.backends.arm.tosa_specification import TosaSpecification
28+
from executorch.backends.xnnpack.test.tester import Quantize
2329

2430
aten_op = "torch.ops.aten.linear.default"
2531

@@ -143,7 +149,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
143149
pipeline.run()
144150

145151

146-
@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness.
147152
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
148153
def test_linear_tosa_INT(test_data: torch.Tensor):
149154
test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -258,3 +263,52 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
258263
per_channel_quantization=per_channel_quantization,
259264
)
260265
pipeline.run()
266+
267+
def get_symmetric_a16w8_linear_quantizer(
    u55_config: bool = False, per_channel_quantization: bool = False
) -> "Quantize":
    """Build the ``Quantize`` stage used by the 16A8W linear tests.

    Args:
        u55_config: Currently unused by this helper; kept so that existing
            callers passing it keep working.
        per_channel_quantization: Forwarded as ``is_per_channel`` to
            ``get_symmetric_a16w8_quantization_config``.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` that applies the
        16A8W config both globally and to ``torch.nn.Linear`` modules.

    Raises:
        KeyError: If the ``tosa_version`` test option is not one of the
            versions listed in ``tosa_profiles`` (only "1.0" at present).
    """
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }
    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])

    # Build the config once and reuse it for the global setting, the
    # Linear-specific setting and the Quantize stage, instead of constructing
    # three identical configs.
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )
    quantizer.set_global(quantization_config)
    quantizer.set_module_type(torch.nn.Linear, quantization_config)

    return Quantize(quantizer, quantization_config)
288+
289+
# Merge the rank-1 and rank-4 cases into the single test-data dict, the same
# way test_linear_tosa_INT does. Passing the rank-4 dict as a separate
# positional argument does not add its cases to the parametrization.
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
    test_data, out_features, has_bias, per_channel_quantization = test_data()
    in_features = test_data.shape[-1]

    # Create the pipeline with the int16 TOSA extension enabled.
    pipeline = TosaPipelineINT[input_t1](
        Linear(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
        ),
        (test_data,),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Swap the pipeline's quantize stage arguments for the custom 16A8W
    # quantizer.
    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_linear_quantizer(
            per_channel_quantization=per_channel_quantization
        ),
    )
    # Run the pipeline
    pipeline.run()

0 commit comments

Comments
 (0)