Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 10 additions & 43 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class AvgPool2dVisitor(NodeVisitor):

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
Expand Down Expand Up @@ -105,43 +106,6 @@ def _build_generic_avgpool2d(
attr,
)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
)

accumulator_type = ts.DType.INT32

input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].get_zp_per_tensor()

output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].get_zp_per_tensor()

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
Expand All @@ -159,14 +123,17 @@ def define_node(
)

if inputs[0].dtype == ts.DType.INT8:
super().define_node(node, tosa_graph, inputs, output)
accumulator_type = ts.DType.INT32
input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].get_zp_per_tensor()

if inputs[0].dtype == ts.DType.FP32:
output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].get_zp_per_tensor()
else:
accumulator_type = ts.DType.FP32
# Initilize zero point to zero.
input_zp = 0
output_zp = 0

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)
self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)
111 changes: 29 additions & 82 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree
Expand Down Expand Up @@ -27,18 +26,19 @@


@register_node_visitor
class ClampVisitor_INT(NodeVisitor):
class ClampVisitor(NodeVisitor):
target = "aten.clamp.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def _get_min_max_arguments(
self, node: Node, dtype_min: int | float, dtype_max: int | float
self, node: Node, dtype: torch.dtype
) -> Tuple[int | float, int | float]:

def cast_type(value: Any) -> int | float:
Expand All @@ -48,6 +48,13 @@ def cast_type(value: Any) -> int | float:
# Attempt to cast to float
return float(value)

if dtype.is_floating_point:
dtype_min = torch.finfo(dtype).min
dtype_max = torch.finfo(dtype).max
else:
dtype_min = torch.iinfo(dtype).min
dtype_max = torch.iinfo(dtype).max

min_arg = dtype_min
max_arg = dtype_max

Expand All @@ -60,53 +67,15 @@ def cast_type(value: Any) -> int | float:

return min_arg, max_arg

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target, [inputs[0], output], [ts.DType.INT8], output.tosa_spec
)

# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
min_int8, max_int8 = self._get_min_max_arguments(
node,
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max,
)

attr = ts.TosaSerializerAttribute()
attr.ClampAttribute(
np.frombuffer(np.int8(min_int8).tobytes(), dtype=np.uint8).tolist(),
np.frombuffer(np.int8(max_int8).tobytes(), dtype=np.uint8).tolist(),
ts.NanPropagationMode.PROPAGATE,
)

self._serialize_operator(
node,
tosa_graph,
ts.Op.CLAMP,
[inputs[0].name],
[output.name],
attr,
)


@register_node_visitor
class ClampVisitor_FP(ClampVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)
def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes:
if dtype == torch.float32:
return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist()
elif dtype == torch.float16:
return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist()
elif dtype == torch.int8:
return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist()
else:
raise ValueError(f"Unsupported dtype for to_bytes: {dtype}")

def define_node(
self,
Expand All @@ -120,42 +89,20 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.FP16, ts.DType.FP32],
[ts.DType.INT8, ts.DType.FP16, ts.DType.FP32],
output.tosa_spec,
)

node_input_dtype = node.meta["val"].dtype
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
min_val, max_val = self._get_min_max_arguments(node, node_input_dtype)

attr = ts.TosaSerializerAttribute()
match inputs[0].dtype:
case ts.DType.FP16:
min_f, max_f = self._get_min_max_arguments(
node,
torch.finfo(torch.float16).min,
torch.finfo(torch.float16).max,
)
min_bytes = np.frombuffer(
np.float16(min_f).tobytes(), dtype=np.uint8
).tolist()
max_bytes = np.frombuffer(
np.float16(max_f).tobytes(), dtype=np.uint8
).tolist()
case ts.DType.FP32:
min_f, max_f = self._get_min_max_arguments(
node,
torch.finfo(torch.float32).min,
torch.finfo(torch.float32).max,
)
min_bytes = np.frombuffer(
np.float32(min_f).tobytes(), dtype=np.uint8
).tolist()
max_bytes = np.frombuffer(
np.float32(max_f).tobytes(), dtype=np.uint8
).tolist()
case _:
raise RuntimeError(
f"Internal error: Unsupported dtype {inputs[0].dtype} in {self.target}"
)

attr.ClampAttribute(min_bytes, max_bytes, ts.NanPropagationMode.PROPAGATE)
attr.ClampAttribute(
self._to_bytes(min_val, node_input_dtype),
self._to_bytes(max_val, node_input_dtype),
nan_mode=ts.NanPropagationMode.PROPAGATE,
)

self._serialize_operator(
node,
Expand Down
66 changes: 15 additions & 51 deletions backends/arm/operators/op_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Sequence
from typing import Any, List

import tosa_serializer as ts

Expand All @@ -23,25 +23,36 @@


@register_node_visitor
class WhereVisitor_INT(NodeVisitor):
class WhereVisitor(NodeVisitor):
target = "aten.where.self"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def _add_node_to_tosa_graph(
def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
supported_dtypes: Sequence,
) -> None:

supported_dtypes = []
if output.tosa_spec.support_integer():
supported_dtypes += [
ts.DType.BOOL,
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
]
if output.tosa_spec.support_float():
supported_dtypes += [ts.DType.BOOL, ts.DType.FP16, ts.DType.FP32]

validate_num_inputs(self.target, inputs, 3)
# Not first input, which is condition tensor.
validate_same_dtype(self.target, inputs[1:], ts)
Expand All @@ -63,50 +74,3 @@ def _add_node_to_tosa_graph(
[output.name],
attr,
)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
bi_supported_dtypes = [
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.BOOL,
]
self._add_node_to_tosa_graph(
node, tosa_graph, inputs, output, bi_supported_dtypes
)


@register_node_visitor
class WhereVisitor_FP(WhereVisitor_INT):

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
mi_supported_dtypes = [
ts.DType.FP16,
ts.DType.FP32,
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.BOOL,
]
self._add_node_to_tosa_graph(
node, tosa_graph, inputs, output, mi_supported_dtypes
)
Loading