
Commit 0b2f656

Andrew Grebenisan authored and facebook-github-bot committed
Removed support for non-per-tensor quantized relu (#14788)
Summary: The default (non-per-tensor) quantized relu is not supported, so it is removed from ref_implementations.

Reviewed By: zonglinpeng

Differential Revision: D83874866
1 parent f443ebb commit 0b2f656

File tree

2 files changed: +31 additions, -82 deletions

backends/cadence/aot/ref_implementations.py
backends/cadence/aot/tests/test_ref_implementations.py

backends/cadence/aot/ref_implementations.py

Lines changed: 10 additions & 39 deletions
@@ -1125,75 +1125,46 @@ def quantized_relu_common(
 
 
 def quantized_relu_variant(
-    per_tensor: bool,
     dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
     """Create a quantized relu variant with type checking."""
 
     def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
         def variant(
             X: torch.Tensor,
-            X_zero_point: torch.Tensor | int,
+            X_zero_point: int,
             out_zero_point: int,
-            out_multiplier: torch.Tensor | int,
-            out_shift: torch.Tensor | int,
+            out_multiplier: int,
+            out_shift: int,
         ) -> torch.Tensor:
-            if per_tensor:
-                if dtype and X.dtype != dtype:
-                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
-
-                assert isinstance(out_shift, int)
-                assert isinstance(out_multiplier, int)
-                _out_shift = out_shift
-                _out_multiplier = out_multiplier
-            else:
-                assert isinstance(out_multiplier, torch.Tensor)
-                if out_multiplier.numel() > 1:
-                    raise ValueError("Only scalar out_multiplier is supported")
-
-                assert isinstance(out_shift, torch.Tensor)
-                if out_shift.numel() > 1:
-                    raise ValueError("Only scalar out_shift is supported")
-
-                assert isinstance(X_zero_point, torch.Tensor)
-                if X_zero_point.shape != X.shape:
-                    raise ValueError(
-                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
-                    )
-
-                _out_multiplier = int(out_multiplier.item())
-                _out_shift = int(out_shift.item())
+            if dtype and X.dtype != dtype:
+                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
 
             return quantized_relu_common(
                 X,
                 X_zero_point,
                 out_zero_point,
-                _out_multiplier,
-                _out_shift,
+                out_multiplier,
+                out_shift,
             )
 
         return variant
 
     return decorator
 
 
-@impl(m, "quantized_relu")
-@quantized_relu_variant(False)
-def quantized_relu() -> torch.Tensor: ...
-
-
 @impl(m, "quantized_relu.per_tensor")
-@quantized_relu_variant(True)
+@quantized_relu_variant()
 def quantized_relu_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
-@quantized_relu_variant(True, torch.int8)
+@quantized_relu_variant(torch.int8)
 def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
-@quantized_relu_variant(True, torch.uint8)
+@quantized_relu_variant(torch.uint8)
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
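
For illustration, the remaining registrations all rely on the same decorator-factory pattern: quantized_relu_variant(dtype) builds a decorator that discards the stub function and returns a dtype-checked wrapper around quantized_relu_common. Below is a minimal, self-contained sketch of that pattern; relu_common and relu_variant are hypothetical stand-ins, and the real requantization math in quantized_relu_common is not reproduced.

from typing import Callable

import torch


def relu_common(X: torch.Tensor, X_zero_point: int) -> torch.Tensor:
    # Hypothetical stand-in for quantized_relu_common: values below the input
    # zero point are clamped to it (requantization omitted).
    return torch.maximum(X, torch.tensor(X_zero_point, dtype=X.dtype))


def relu_variant(
    dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Return a decorator that replaces a stub with a dtype-checked variant."""

    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(X: torch.Tensor, X_zero_point: int) -> torch.Tensor:
            if dtype and X.dtype != dtype:
                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
            return relu_common(X, X_zero_point)

        return variant

    return decorator


@relu_variant(torch.int8)
def relu_asym8s(X: torch.Tensor, X_zero_point: int) -> torch.Tensor: ...


print(relu_asym8s(torch.tensor([-3, 0, 4], dtype=torch.int8), 0))
# tensor([0, 0, 4], dtype=torch.int8)

Keeping the dtype check in the factory is what lets each op registration stay a one-line stub under the @impl decorator.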

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 21 additions & 43 deletions
@@ -1080,61 +1080,39 @@ def test_quantized_conv_per_tensor(
                 )
                 for dtype in [torch.uint8]
             ],
-            # Test case 4: Non-per-tensor
-            *[
-                (
-                    "non_per_tensor",
-                    torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype),  # input
-                    torch.tensor([0, 0, 0, 1, 1, 1]),  # X_zero_point
-                    5,  # out_zero_point
-                    torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                    torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
-                    dtype,  # dtype
-                    torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
-                )
-                for dtype in [torch.int8]
-            ],
         ]
     )
     def test_quantized_relu(
         self,
        name: str,
        X: torch.Tensor,
-        X_zero_point: torch.Tensor | int,
+        X_zero_point: int,
        out_zero_point: int,
-        out_multiplier: torch.Tensor | int,
-        out_shift: torch.Tensor | int,
+        out_multiplier: int,
+        out_shift: int,
        dtype: torch.dtype,
        expected_output: torch.Tensor,
    ) -> None:
 
-        if isinstance(X_zero_point, int):
-            assert isinstance(out_multiplier, int)
-            assert isinstance(out_shift, int)
-
-            match dtype:
-                case torch.int8:
-                    quantized_relu = (
-                        torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
-                    )
-                case torch.uint8:
-                    quantized_relu = (
-                        torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
-                    )
-                case _:
-                    quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
+        match dtype:
+            case torch.int8:
+                quantized_relu = (
+                    torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
+                )
+            case torch.uint8:
+                quantized_relu = (
+                    torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
+                )
+            case _:
+                quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
 
-            output = quantized_relu(
-                X,
-                X_zero_point,
-                out_zero_point,
-                out_multiplier,
-                out_shift,
-            )
-        else:
-            output = torch.ops.cadence.quantized_relu(
-                X, X_zero_point, out_zero_point, out_multiplier, out_shift
-            )
+        output = quantized_relu(
+            X,
+            X_zero_point,
+            out_zero_point,
+            out_multiplier,
+            out_shift,
+        )
 
         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
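
With the non-per-tensor overload removed, callers pass plain Python ints for X_zero_point, out_multiplier, and out_shift. A hypothetical invocation mirroring the updated test is sketched below; it assumes the Cadence reference ops have been registered (for example by importing backends/cadence/aot/ref_implementations.py), and the argument values are illustrative only.

import torch

# Assumes the Cadence reference ops are already registered; values below are
# illustrative, not taken from a specific test expectation.
X = torch.tensor([-1, -2, -3, 1, 2, 3], dtype=torch.int8)
output = torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor(
    X,
    0,           # X_zero_point: a plain int; tensors are no longer accepted
    5,           # out_zero_point
    1073741824,  # out_multiplier (0.5 * 2^31)
    1,           # out_shift
)
print(output.dtype)  # expected to match X.dtype (torch.int8) per the test above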
