Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 10 additions & 39 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,75 +1125,46 @@ def quantized_relu_common(


def quantized_relu_variant(
per_tensor: bool,
dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
"""Create a quantized relu variant with type checking."""

def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
def variant(
X: torch.Tensor,
X_zero_point: torch.Tensor | int,
X_zero_point: int,
out_zero_point: int,
out_multiplier: torch.Tensor | int,
out_shift: torch.Tensor | int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
if per_tensor:
if dtype and X.dtype != dtype:
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")

assert isinstance(out_shift, int)
assert isinstance(out_multiplier, int)
_out_shift = out_shift
_out_multiplier = out_multiplier
else:
assert isinstance(out_multiplier, torch.Tensor)
if out_multiplier.numel() > 1:
raise ValueError("Only scalar out_multiplier is supported")

assert isinstance(out_shift, torch.Tensor)
if out_shift.numel() > 1:
raise ValueError("Only scalar out_shift is supported")

assert isinstance(X_zero_point, torch.Tensor)
if X_zero_point.shape != X.shape:
raise ValueError(
f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
)

_out_multiplier = int(out_multiplier.item())
_out_shift = int(out_shift.item())
if dtype and X.dtype != dtype:
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")

return quantized_relu_common(
X,
X_zero_point,
out_zero_point,
_out_multiplier,
_out_shift,
out_multiplier,
out_shift,
)

return variant

return decorator


@impl(m, "quantized_relu")
@quantized_relu_variant(False)
def quantized_relu() -> torch.Tensor: ...


@impl(m, "quantized_relu.per_tensor")
@quantized_relu_variant(True)
@quantized_relu_variant()
def quantized_relu_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
@quantized_relu_variant(True, torch.int8)
@quantized_relu_variant(torch.int8)
def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
@quantized_relu_variant(True, torch.uint8)
@quantized_relu_variant(torch.uint8)
def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...


Expand Down
64 changes: 21 additions & 43 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,61 +1080,39 @@ def test_quantized_conv_per_tensor(
)
for dtype in [torch.uint8]
],
# Test case 4: Non-per-tensor
*[
(
"non_per_tensor",
torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype), # input
torch.tensor([0, 0, 0, 1, 1, 1]), # X_zero_point
5, # out_zero_point
torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31)
torch.tensor([1]), # out_shift (multiply by 2^1 = 2)
dtype, # dtype
torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
)
for dtype in [torch.int8]
],
]
)
def test_quantized_relu(
self,
name: str,
X: torch.Tensor,
X_zero_point: torch.Tensor | int,
X_zero_point: int,
out_zero_point: int,
out_multiplier: torch.Tensor | int,
out_shift: torch.Tensor | int,
out_multiplier: int,
out_shift: int,
dtype: torch.dtype,
expected_output: torch.Tensor,
) -> None:

if isinstance(X_zero_point, int):
assert isinstance(out_multiplier, int)
assert isinstance(out_shift, int)

match dtype:
case torch.int8:
quantized_relu = (
torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
)
case torch.uint8:
quantized_relu = (
torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
)
case _:
quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
match dtype:
case torch.int8:
quantized_relu = (
torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
)
case torch.uint8:
quantized_relu = (
torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
)
case _:
quantized_relu = torch.ops.cadence.quantized_relu_per_tensor

output = quantized_relu(
X,
X_zero_point,
out_zero_point,
out_multiplier,
out_shift,
)
else:
output = torch.ops.cadence.quantized_relu(
X, X_zero_point, out_zero_point, out_multiplier, out_shift
)
output = quantized_relu(
X,
X_zero_point,
out_zero_point,
out_multiplier,
out_shift,
)

# Verify output properties
self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
Expand Down
Loading