
Commit 79daffe

Qualcomm AI Engine Direct - enable operators adaptive_max_pool2d and grid_sampler

Enable operators adaptive_max_pool2d and grid_sampler (2D and 3D).

Test commands:

```bash
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_adaptive_max_pool2d -b build-android -H $HOST -s $SN -m $CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_adaptive_max_pool2d -b build-android -H $HOST -s $SN -m $CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_grid_sampler -b build-android -H $HOST -s $SN -m $CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_grid_sampler -b build-android -H $HOST -s $SN -m $CHIPID
```
1 parent f4d801a commit 79daffe

File tree

11 files changed: +525 −3 lines changed


backends/qualcomm/_passes/layout_transform.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -42,8 +42,11 @@ class LayoutTransform(ExportPass):
 
     layout_sensitive_ops = {
         exir_ops.edge.aten.adaptive_avg_pool2d.default,
+        exir_ops.edge.aten.adaptive_max_pool2d.default,
         exir_ops.edge.aten.avg_pool2d.default,
         exir_ops.edge.aten.convolution.default,
+        exir_ops.edge.aten.grid_sampler_2d.default,
+        exir_ops.edge.aten.grid_sampler_3d.default,
         exir_ops.edge.aten.instance_norm.default,
         exir_ops.edge.aten.max_pool2d_with_indices.default,
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
```
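The new entries tell the pass that these ops run in channel-last layout on QNN, so the pass wraps them with the NCHW↔NHWC transposes they need. A minimal eager-mode sketch of that round trip (illustrative only; the pass rewrites the exported graph rather than calling `permute`):

```python
import torch

# The layout-transform pass conceptually does this around each
# layout-sensitive op: convert NCHW activations to NHWC for the HW op,
# then convert the result back for NCHW consumers.
x_nchw = torch.randn(1, 3, 8, 8)       # ATen layout: (N, C, H, W)
x_nhwc = x_nchw.permute(0, 2, 3, 1)    # to (N, H, W, C) for the QNN op
# ... the NHWC op would execute here ...
x_back = x_nhwc.permute(0, 3, 1, 2)    # back to (N, C, H, W)
assert torch.equal(x_nchw, x_back)
```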

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -431,7 +431,7 @@ Please help update following table if you are contributing new operators:
 | Gelu | ✓ |
 | GetSparseIndices | ✗ |
 | GetSparseValues | ✗ |
-| GridSample | ✗ |
+| GridSample | ✓ |
 | GroupNorm | ✓ |
 | HardSwish | ✓ |
 | InstanceNorm | ✓ |
```

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
     node_visitor,
     op_abs,
     op_adaptive_avg_pool2d,
+    op_adaptive_max_pool2d,
     op_add,
     op_amax,
     op_amin,
@@ -43,6 +44,7 @@
     op_gather,
     op_ge,
     op_gelu,
+    op_grid_sampler_2d,
     op_group_norm,
     op_gt,
     op_hardsigmoid,
@@ -113,6 +115,7 @@
     node_visitor,
     op_abs,
     op_adaptive_avg_pool2d,
+    op_adaptive_max_pool2d,
     op_add,
     op_amax,
     op_amin,
@@ -148,6 +151,7 @@
     op_gather,
     op_ge,
     op_gelu,
+    op_grid_sampler_2d,
     op_group_norm,
     op_gt,
     op_hardswish,
```
backends/qualcomm/builders/op_adaptive_max_pool2d.py

Lines changed: 151 additions & 0 deletions (new file)
```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class AdaptiveMaxPool2D(NodeVisitor):
    target = ["aten.adaptive_max_pool2d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        users = list(node.users.keys())
        for user in users:
            if user.target.__name__ == "getitem":
                getitem_index = user.args[1]
                if getitem_index != 0:
                    warnings.warn(
                        f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}",
                        stacklevel=1,
                    )
                    return

        if len(node.args) > 2:
            warnings.warn(
                "[QNN Delegate Op Builder]: return_indices is not supported, falling back",
                stacklevel=1,
            )
            return

        # Tensors arrive in NHWC after the layout-transform pass, so dims 1/2
        # are height/width.
        input_height = input_tensor.shape[1]
        input_width = input_tensor.shape[2]
        # Resolve the requested output size; it may be a single int or an
        # (H, W) pair, and either entry may be None (keep the input size).
        out_wh = cast(List[int], node.args[1])
        if len(out_wh) == 1:
            output_height = node.args[1][0]
            output_width = node.args[1][0]
        else:
            output_height = node.args[1][0]
            output_width = node.args[1][1]
        if output_height is None:
            output_height = input_height
        if output_width is None:
            output_width = input_width
        # NOTE: The rounding mode is irrelevant here because the output shape
        # is fully determined by the user-specified output size.
        mode = OpPoolMax2d.RoundingMode.FLOOR

        # Derive fixed pooling parameters from the adaptive output size
        # (floor division).
        stride_height = input_height // output_height
        filter_height = input_height - (output_height - 1) * stride_height
        stride_width = input_width // output_width
        filter_width = input_width - (output_width - 1) * stride_width

        filter = [filter_height, filter_width]
        filter_shape = [len(filter)]

        stride = [stride_height, stride_width]
        stride_shape = [len(stride)]

        padding = [0, 0]
        padding_shape = [len(padding), len(padding)]

        out_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        adaptive_max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolMax2d.op_name,
        )

        adaptive_max_pool2d_op.AddInputTensors([input_tensor_wrapper])
        adaptive_max_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_shape),
            filter_shape,
            np.array(
                filter,
                dtype=np.uint32,
            ),
            True,
        )

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(
                stride,
                dtype=np.uint32,
            ),
            True,
        )

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )

        adaptive_max_pool2d_op.AddScalarParam(
            OpPoolMax2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return adaptive_max_pool2d_op
```
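The builder turns the adaptive pooling spec into a fixed `PoolMax2d` with `stride = input // output` and `filter = input - (output - 1) * stride`. When the input size divides evenly by the output size this reproduces ATen exactly (for non-divisible sizes ATen varies its window boundaries per output element, so results can differ). A quick sanity check of the derivation, as a sketch rather than part of the commit:

```python
import torch
import torch.nn.functional as F

# Check the stride/filter derivation for an evenly divisible case (8 -> 4).
in_h = in_w = 8
out_h = out_w = 4
stride_h, stride_w = in_h // out_h, in_w // out_w   # 2, 2
kernel_h = in_h - (out_h - 1) * stride_h            # 2
kernel_w = in_w - (out_w - 1) * stride_w            # 2

x = torch.randn(1, 3, in_h, in_w)
adaptive = F.adaptive_max_pool2d(x, (out_h, out_w))
fixed = F.max_pool2d(
    x, kernel_size=(kernel_h, kernel_w), stride=(stride_h, stride_w)
)
assert torch.equal(adaptive, fixed)
```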
backends/qualcomm/builders/op_grid_sampler_2d.py

Lines changed: 162 additions & 0 deletions (new file)
```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_DTYPE

from .node_visitor import NodeVisitor, QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpGridSample, OpTranspose, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class GridSample(NodeVisitor):
    target = ["aten.grid_sampler_2d.default", "aten.grid_sampler_3d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        grid_sample_op_list = []
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        grid_node = self.get_node(node.args[1])
        grid_tensor = self.get_tensor(grid_node, node)
        grid_tensor_wrapper = self.define_tensor(
            grid_node,
            node,
            grid_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        input_shape = input_node.meta["val"].shape
        input_rank = len(input_shape)
        if input_rank not in [4, 5]:
            warnings.warn(
                f"[QNN Delegate Op Builder]: Input rank {input_rank} is not supported, falling back",
                stacklevel=1,
            )
            return

        # In ATen, the layouts of input_tensor and grid_tensor are not
        # identical, but on HW both must be NHWC (or NDHWC), so the grid
        # needs one more shape transformation here.
        if input_rank == 4:
            dims_shape_back = (0, 3, 1, 2)
        elif input_rank == 5:
            dims_shape_back = (0, 4, 1, 2, 3)
        else:
            warnings.warn(
                f"[QNN Delegate Op Builder]: Rank {input_rank} is not supported, falling back",
                stacklevel=1,
            )
            return

        grid_quant_encoding, grid_quant_configs = self.get_quant_encoding_conf(
            grid_node, node
        )
        grid_dtype = (
            QNN_TENSOR_TYPE_MAP[grid_tensor.dtype]
            if grid_quant_encoding
            == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
            else QNN_QUANT_TYPE_MAP[
                (
                    torch.uint16
                    if grid_quant_configs[QCOM_DTYPE] == torch.int32
                    else grid_quant_configs[QCOM_DTYPE]
                )
            ]
        )
        # Transpose the grid into the layout the HW op expects.
        permute_output_tensor = grid_tensor.permute(dims=dims_shape_back)
        transpose_output_tensor_wrapper = self.define_custom_tensor_wrapper(
            node_name=node.name + "_transpose",
            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            dtype=grid_dtype,
            quant_encoding=grid_quant_encoding,
            quant_configs=grid_quant_configs,
            dims=permute_output_tensor.size(),
            tensor=permute_output_tensor,
            is_fake_tensor=True,
            nodes_to_wrappers=nodes_to_wrappers,
        )

        permute_order = cast(List[int], dims_shape_back)
        permute_order_shape = [len(permute_order)]
        transpose_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTranspose.op_name,
        )
        transpose_op.AddInputTensors([grid_tensor_wrapper])
        transpose_op.AddOutputTensors([transpose_output_tensor_wrapper])
        transpose_op.AddTensorParam(
            OpTranspose.param_perm,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(permute_order_shape),
            permute_order_shape,
            np.array(permute_order, dtype=np.uint32),
            True,
        )
        grid_sample_op_list.append(transpose_op)

        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        align_corners = node.args[4] if len(node.args) > 4 else False
        padding_mode = node.args[3] if len(node.args) > 3 else 0
        interpo_mode = node.args[2] if len(node.args) > 2 else 0

        grid_sample_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGridSample.op_name,
        )
        grid_sample_op.AddInputTensors(
            [input_tensor_wrapper, transpose_output_tensor_wrapper]
        )
        grid_sample_op.AddOutputTensors([output_tensor_wrapper])
        grid_sample_op.AddScalarParam(
            OpGridSample.param_align_corners,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: align_corners},
        )
        grid_sample_op.AddScalarParam(
            OpGridSample.param_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(interpo_mode)},
        )
        grid_sample_op.AddScalarParam(
            OpGridSample.param_padding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(padding_mode)},
        )
        grid_sample_op_list.append(grid_sample_op)
        return grid_sample_op_list
```
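ATen's `grid_sample` takes the input as (N, C, H, W) and the grid as (N, H_out, W_out, 2). Assuming the layout pass applies its standard NCHW→NHWC permute `(0, 2, 3, 1)` to every input, the grid ends up as (N, W_out, 2, H_out), and the Transpose emitted above with perm `(0, 3, 1, 2)` restores the (N, H_out, W_out, 2) order the HW op consumes. A small check of that composition (illustrative sketch, not the builder's code path):

```python
import torch

# The grid is already "channel last" in ATen: (N, H_out, W_out, 2).
grid = torch.randn(1, 6, 6, 2)

# What the layout pass would produce if it treats the grid like any NCHW
# activation (assumption: the standard (0, 2, 3, 1) permute).
as_nhwc = grid.permute(0, 2, 3, 1)      # (N, W_out, 2, H_out)

# The Transpose emitted by this builder: perm = dims_shape_back = (0, 3, 1, 2).
restored = as_nhwc.permute(0, 3, 1, 2)  # back to (N, H_out, W_out, 2)

assert torch.equal(grid, restored)
```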

backends/qualcomm/builders/qnn_constants.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -304,6 +304,24 @@ class OpGather:
     param_axis: str = "axis"
 
 
+class OpGridSample:
+    op_name: str = "GridSample"
+    param_align_corners: str = "align_corners"
+    param_mode: str = "mode"
+    param_padding_mode: str = "padding_mode"
+
+    @unique
+    class Mode(IntEnum):
+        BILINEAR = 0
+        NEAREST = 1
+
+    @unique
+    class PaddingMode(IntEnum):
+        ZEROS = 0
+        BORDER = 1
+        REFLECTION = 2
+
+
 @dataclass(init=False, frozen=True)
 class OpGatherElements:
     op_name: str = "GatherElements"
```
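The `Mode` and `PaddingMode` values line up with the integer arguments ATen passes to `aten.grid_sampler_2d/3d` (interpolation: 0 = bilinear, 1 = nearest; padding: 0 = zeros, 1 = border, 2 = reflection), which is why the builder can forward `node.args[2]` and `node.args[3]` as uint32 scalars unchanged. A short eager-mode illustration of the mapping:

```python
import torch
import torch.nn.functional as F

# F.grid_sample lowers to aten.grid_sampler_2d with integer mode/padding
# arguments matching OpGridSample.Mode / OpGridSample.PaddingMode.
x = torch.randn(1, 3, 8, 8)
grid = torch.rand(1, 6, 6, 2) * 2 - 1   # normalized coords in [-1, 1]
y = F.grid_sample(
    x,
    grid,
    mode="nearest",          # -> interpolation_mode = 1 (Mode.NEAREST)
    padding_mode="border",   # -> padding_mode = 1 (PaddingMode.BORDER)
    align_corners=False,
)
print(y.shape)               # torch.Size([1, 3, 6, 6])
```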
