Skip to content

Commit bb0518c

Browse files
Ninja91 authored and facebook-github-bot committed
Add 16A8W support and test for add operation (#13568)
Summary: Add 16A8W quantization support and a test for the add operation in the ExecuTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to add operations.

Changes:
- Add INT16 dtype validation support in op_add.py
- Add the test_add_tensor_16a8w_tosa_INT test function
- Enable test_add.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency.

Differential Revision: D80510463
1 parent 440cad9 commit bb0518c

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

backends/arm/operators/op_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def define_node(
5050
validate_valid_dtype(
5151
self.target,
5252
[*inputs, output],
53-
[ts.DType.INT8, ts.DType.INT32],
53+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5454
output.tosa_spec,
5555
)
5656

backends/arm/test/ops/test_add.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import pytest
1111
import torch
1212
from executorch.backends.arm.quantizer import arm_quantizer
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1317
from executorch.backends.arm.test import common, conftest
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
@@ -216,3 +220,46 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
216220
tosa_version="TOSA-1.0+INT",
217221
)
218222
pipeline.run()
223+
224+
225+
def get_symmetric_a16w8_add_quantizer(u55_config=False, per_channel_quantization=False):
    """Build the quantize stage used by the 16A8W add test.

    A ``TOSAQuantizer`` is created for the TOSA specification that matches
    the ``tosa_version`` test option, and a symmetric a16w8 quantization
    config is installed as its global config.

    Args:
        u55_config: Not used by this function.
        per_channel_quantization: Forwarded as ``is_per_channel`` to
            ``get_symmetric_a16w8_quantization_config``.

    Returns:
        A ``Quantize`` object constructed from the quantizer and a second,
        separately created a16w8 quantization config.

    Raises:
        KeyError: If the ``tosa_version`` option has no entry in the
            specification table below (only "1.0" is listed).
    """
    requested_version = conftest.get_option("tosa_version")

    # Specification lookup keyed by the TOSA version string from the options.
    spec_by_version = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }
    selected_spec = spec_by_version[requested_version]

    a16w8_quantizer = TOSAQuantizer(selected_spec)
    global_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )
    a16w8_quantizer.set_global(global_config)

    # The stage receives its own config instance, created after set_global().
    stage_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )
    return Quantize(a16w8_quantizer, stage_config)
242+
243+
244+
@common.parametrize("test_data", Add.test_data)
def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
    """Run the add operator through the TOSA INT pipeline with 16A8W quantization.

    The pipeline is created with the "int16" TOSA extension, after which its
    "quantize" stage arguments are replaced with the a16w8 quantize stage
    from ``get_symmetric_a16w8_add_quantizer``.
    """
    per_channel = False

    add_pipeline = TosaPipelineINT[input_t1](
        Add(),
        test_data(),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Swap the pipeline's "quantize" stage arguments for the 16A8W stage.
    a16w8_stage = get_symmetric_a16w8_add_quantizer(
        per_channel_quantization=per_channel
    )
    add_pipeline.change_args("quantize", a16w8_stage)

    add_pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_arm_tests():
1313

1414
# Operators
1515
test_files += [
16+
"ops/test_add.py",
1617
"ops/test_avg_pool2d.py",
1718
"ops/test_linear.py",
1819
"ops/test_slice.py",

0 commit comments

Comments
 (0)