@@ -1125,75 +1125,46 @@ def quantized_relu_common(
 
 
 def quantized_relu_variant(
-    per_tensor: bool,
     dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
     """Create a quantized relu variant with type checking."""
 
     def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
         def variant(
             X: torch.Tensor,
-            X_zero_point: torch.Tensor | int,
+            X_zero_point: int,
             out_zero_point: int,
-            out_multiplier: torch.Tensor | int,
-            out_shift: torch.Tensor | int,
+            out_multiplier: int,
+            out_shift: int,
         ) -> torch.Tensor:
-            if per_tensor:
-                if dtype and X.dtype != dtype:
-                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
-
-                assert isinstance(out_shift, int)
-                assert isinstance(out_multiplier, int)
-                _out_shift = out_shift
-                _out_multiplier = out_multiplier
-            else:
-                assert isinstance(out_multiplier, torch.Tensor)
-                if out_multiplier.numel() > 1:
-                    raise ValueError("Only scalar out_multiplier is supported")
-
-                assert isinstance(out_shift, torch.Tensor)
-                if out_shift.numel() > 1:
-                    raise ValueError("Only scalar out_shift is supported")
-
-                assert isinstance(X_zero_point, torch.Tensor)
-                if X_zero_point.shape != X.shape:
-                    raise ValueError(
-                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
-                    )
-
-                _out_multiplier = int(out_multiplier.item())
-                _out_shift = int(out_shift.item())
+            if dtype and X.dtype != dtype:
+                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
 
             return quantized_relu_common(
                 X,
                 X_zero_point,
                 out_zero_point,
-                _out_multiplier,
-                _out_shift,
+                out_multiplier,
+                out_shift,
             )
 
         return variant
 
     return decorator
 
 
-@impl(m, "quantized_relu")
-@quantized_relu_variant(False)
-def quantized_relu() -> torch.Tensor: ...
-
-
 @impl(m, "quantized_relu.per_tensor")
-@quantized_relu_variant(True)
+@quantized_relu_variant()
 def quantized_relu_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
-@quantized_relu_variant(True, torch.int8)
+@quantized_relu_variant(torch.int8)
 def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
-@quantized_relu_variant(True, torch.uint8)
+@quantized_relu_variant(torch.uint8)
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
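For context on the pattern this hunk simplifies, below is a minimal, self-contained sketch of the same decorator-factory shape, assuming only PyTorch and Python 3.10+. The names checked_variant and relu_int8 are hypothetical stand-ins, not part of this diff; only the dtype guard mirrors the patched quantized_relu_variant, and the body is a placeholder rather than the real quantized_relu_common call.

import torch
from typing import Callable


def checked_variant(
    dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Decorator factory in the same shape as the simplified quantized_relu_variant."""

    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(X: torch.Tensor) -> torch.Tensor:
            # Same guard the patched variant keeps: reject inputs whose dtype
            # does not match the dtype this variant was registered for.
            if dtype and X.dtype != dtype:
                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
            # Placeholder body; the real op forwards to quantized_relu_common.
            return X

        return variant

    return decorator


@checked_variant(torch.int8)
def relu_int8() -> torch.Tensor: ...


x = torch.zeros(4, dtype=torch.int8)
relu_int8(x)  # passes the dtype guard
# relu_int8(x.to(torch.uint8))  # would raise: X dtype must be torch.int8

Note that the stub body (...) never runs: the decorator discards the decorated function (the _ parameter) and registers variant in its place, which is how each registered op in the diff shares one implementation while differing only in its dtype check.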