diff --git a/python/core/model_processing/converters/onnx_converter.py b/python/core/model_processing/converters/onnx_converter.py index 0037d807..936d1d1d 100644 --- a/python/core/model_processing/converters/onnx_converter.py +++ b/python/core/model_processing/converters/onnx_converter.py @@ -1022,7 +1022,9 @@ def _extract_model_io_info( ] self.input_shape = get_input_shapes(onnx_model) - def get_weights(self: ONNXConverter) -> tuple[ + def get_weights( + self: ONNXConverter, + ) -> tuple[ dict[str, list[ONNXLayerDict]], dict[str, list[ONNXLayerDict]], CircuitParamsDict, @@ -1055,7 +1057,7 @@ def get_weights(self: ONNXConverter) -> tuple[ scale_base=scale_base, ) # Get layers in correct format - (architecture, w_and_b) = self.analyze_layers( + architecture, w_and_b = self.analyze_layers( scaled_and_transformed_model, output_name_to_shape, ) diff --git a/python/core/model_processing/onnx_quantizer/layers/squeeze.py b/python/core/model_processing/onnx_quantizer/layers/squeeze.py new file mode 100644 index 00000000..60106110 --- /dev/null +++ b/python/core/model_processing/onnx_quantizer/layers/squeeze.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +import numpy as np +from onnx import numpy_helper + +from python.core.model_processing.onnx_quantizer.exceptions import ( + InvalidParamError, +) +from python.core.model_processing.onnx_quantizer.layers.base import ( + BaseOpQuantizer, + QuantizerBase, + ScaleConfig, +) + +if TYPE_CHECKING: + import onnx + + +class QuantizeSqueeze(QuantizerBase): + OP_TYPE = "Squeeze" + DOMAIN = "" + USE_WB = False + USE_SCALING = False + # Only the data input is relevant for scale-planning. + SCALE_PLAN: ClassVar = {0: 1} + + +class SqueezeQuantizer(BaseOpQuantizer, QuantizeSqueeze): + """ + Quantizer for ONNX Squeeze. + + Squeeze is scale-preserving (pure shape/view transform): + - No arithmetic + - No rescaling + - No custom op + + We support: + - axes as an attribute (older opsets) + - axes as a constant initializer input (newer opsets) + + We do NOT support dynamic axes provided at runtime. + """ + + def __init__( + self: SqueezeQuantizer, + new_initializers: list[onnx.TensorProto] | None = None, + ) -> None: + super().__init__() + if new_initializers is not None: + self.new_initializers = new_initializers + + def quantize( + self: SqueezeQuantizer, + node: onnx.NodeProto, + graph: onnx.GraphProto, + scale_config: ScaleConfig, + initializer_map: dict[str, onnx.TensorProto], + ) -> list[onnx.NodeProto]: + # Pure passthrough; QuantizerBase handles standard bookkeeping. + return QuantizeSqueeze.quantize( + self, + node, + graph, + scale_config, + initializer_map, + ) + + _N_INPUTS_NO_AXES: ClassVar[int] = 1 + _N_INPUTS_WITH_AXES: ClassVar[int] = 2 + + def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None: + for attr in node.attribute: + if attr.name == "axes": + return list(attr.ints) + return None + + def _get_axes_from_initializer_input( + self, + node: onnx.NodeProto, + initializer_map: dict[str, onnx.TensorProto], + ) -> list[int]: + axes_name = node.input[1] + if axes_name not in initializer_map: + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=( + f"Dynamic axes input is not supported for Squeeze " + f"(expected axes '{axes_name}' to be an initializer)." + ), + ) + + axes_tensor = initializer_map[axes_name] + arr = numpy_helper.to_array(axes_tensor) + + if not np.issubdtype(arr.dtype, np.integer): + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"Squeeze axes initializer must be integer, got {arr.dtype}.", + attr_key="axes", + expected="integer tensor (0-D or 1-D)", + ) + + if arr.ndim == 0: + return [int(arr)] + if arr.ndim == 1: + return [int(x) for x in arr.tolist()] + + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"Squeeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.", + attr_key="axes", + expected="0-D scalar or 1-D list of axes", + ) + + def check_supported( + self: SqueezeQuantizer, + node: onnx.NodeProto, + initializer_map: dict[str, onnx.TensorProto] | None = None, + ) -> None: + self.validate_node_has_output(node) + initializer_map = initializer_map or {} + + n_inputs = len(node.input) + if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES): + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"Squeeze expects 1 or 2 inputs, got {n_inputs}.", + ) + + axes = self._get_axes_from_attribute(node) + + # If axes is provided as a second input, it must be a constant initializer. + if axes is None and n_inputs == self._N_INPUTS_WITH_AXES: + axes = self._get_axes_from_initializer_input(node, initializer_map) + + # If axes is omitted entirely, ONNX semantics are "remove all dims of size 1". + # We can't validate legality here without rank/shape; defer to Rust. + + if axes is None: + return + + if len(set(axes)) != len(axes): + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"axes must not contain duplicates: {axes}", + attr_key="axes", + expected="axes list with unique entries", + ) diff --git a/python/core/model_processing/onnx_quantizer/layers/unsqueeze.py b/python/core/model_processing/onnx_quantizer/layers/unsqueeze.py new file mode 100644 index 00000000..2d1f6db2 --- /dev/null +++ b/python/core/model_processing/onnx_quantizer/layers/unsqueeze.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +import numpy as np +from onnx import helper, numpy_helper + +from python.core.model_processing.converters.base import ModelType +from python.core.model_processing.errors import LayerAnalysisError +from python.core.model_processing.onnx_custom_ops.onnx_helpers import parse_attributes +from python.core.model_processing.onnx_quantizer.exceptions import ( + InvalidParamError, +) +from python.core.model_processing.onnx_quantizer.layers.base import ( + BaseOpQuantizer, + QuantizerBase, + ScaleConfig, +) + +if TYPE_CHECKING: + import onnx + +_N_UNSQUEEZE_INPUTS: int = 2 + + +class QuantizeUnsqueeze(QuantizerBase): + OP_TYPE = "Unsqueeze" + DOMAIN = "" + USE_WB = False + USE_SCALING = False + # Only the data input is relevant for scale-planning. + SCALE_PLAN: ClassVar = {0: 1} + + +class UnsqueezeQuantizer(BaseOpQuantizer, QuantizeUnsqueeze): + """ + Quantizer for ONNX Unsqueeze. + + Unsqueeze is scale-preserving (pure shape/view transform): + - No arithmetic + - No rescaling + - No custom op + + Semantics: + - Inserts new dimensions of size 1 at the specified axes positions. + + We support: + - axes as an attribute (older opsets) + - axes as a constant initializer input (opset >= 13 style) + + We do NOT support dynamic axes provided at runtime. + """ + + def __init__( + self: UnsqueezeQuantizer, + new_initializers: list[onnx.TensorProto] | None = None, + ) -> None: + super().__init__() + if new_initializers is not None: + self.new_initializers = new_initializers + + def quantize( + self: UnsqueezeQuantizer, + node: onnx.NodeProto, + graph: onnx.GraphProto, + scale_config: ScaleConfig, + initializer_map: dict[str, onnx.TensorProto], + ) -> list[onnx.NodeProto]: + # Pure passthrough; QuantizerBase handles standard bookkeeping. + return QuantizeUnsqueeze.quantize( + self, + node, + graph, + scale_config, + initializer_map, + ) + + def pre_analysis_transform( + self: UnsqueezeQuantizer, + node: onnx.NodeProto, + graph: onnx.GraphProto, + initializer_map: dict[str, onnx.TensorProto], + scale_base: int, + scale_exponent: int, + ) -> None: + _ = initializer_map, scale_base, scale_exponent + model_type = ModelType.ONNX + params = parse_attributes(node.attribute) + if node.op_type != "Unsqueeze": + return + if params and "axes" in params: + return + axes = _extract_unsqueeze_axes_into_params( + name=node.name, + inputs=node.input, + params=params, + graph=graph, + model_type=model_type, + initializer_map=initializer_map, + ) + attr = helper.make_attribute("axes", axes["axes"]) + node.attribute.append(attr) + + _N_INPUTS_NO_AXES: ClassVar[int] = 1 + _N_INPUTS_WITH_AXES: ClassVar[int] = 2 + + def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None: + for attr in node.attribute: + if attr.name == "axes": + return list(attr.ints) + return None + + def _get_axes_from_initializer_input( + self, + node: onnx.NodeProto, + initializer_map: dict[str, onnx.TensorProto], + ) -> list[int]: + axes_name = node.input[1] + if axes_name not in initializer_map: + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=( + f"Dynamic axes input is not supported for Unsqueeze " + f"(expected axes '{axes_name}' to be an initializer)." + ), + ) + + axes_tensor = initializer_map[axes_name] + arr = numpy_helper.to_array(axes_tensor) + + if not np.issubdtype(arr.dtype, np.integer): + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"Unsqueeze axes initializer must be integer, got {arr.dtype}.", + attr_key="axes", + expected="integer tensor (0-D or 1-D)", + ) + + if arr.ndim == 0: + return [int(arr)] + if arr.ndim == 1: + return [int(x) for x in arr.tolist()] + + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"Unsqueeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.", + attr_key="axes", + expected="0-D scalar or 1-D list of axes", + ) + + def check_supported( + self: UnsqueezeQuantizer, + node: onnx.NodeProto, + initializer_map: dict[str, onnx.TensorProto] | None = None, + ) -> None: + self.validate_node_has_output(node) + initializer_map = initializer_map or {} + + n_inputs = len(node.input) + if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES): + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=( + "Unsqueeze expects either 1 input (axes as attribute) or 2 inputs " + f"(axes as initializer), got {n_inputs}." + ), + ) + + axes = self._get_axes_from_attribute(node) + + # ONNX Unsqueeze has two schema styles: + # - newer: Unsqueeze(data, axes) -> 2 inputs, axes is initializer input + # - older: Unsqueeze(data) with axes attribute -> 1 input + if n_inputs == self._N_INPUTS_NO_AXES: + if axes is None: + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=( + "Unsqueeze with 1 input is only supported when 'axes' is " + "provided as an attribute (older opsets)." + ), + attr_key="axes", + expected="axes attribute", + ) + elif axes is None: + axes = self._get_axes_from_initializer_input(node, initializer_map) + + if axes is None: + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message="Unsqueeze requires 'axes' to be provided.", + attr_key="axes", + expected="axes attribute or initializer input", + ) + + if len(set(axes)) != len(axes): + raise InvalidParamError( + node_name=node.name, + op_type=node.op_type, + message=f"axes must not contain duplicates: {axes}", + attr_key="axes", + expected="axes list with unique entries", + ) + + +def _extract_unsqueeze_axes_into_params( # noqa: PLR0913 + *, + name: str, + inputs: list[str] | tuple[str, ...], + params: dict | None, + graph: onnx.GraphProto, + model_type: ModelType, + initializer_map: dict[str, onnx.TensorProto] | None = None, +) -> dict: + if len(inputs) != _N_UNSQUEEZE_INPUTS: + msg = ( + f"Unsqueeze '{name}' is missing axes input. " + f"Expected 2 inputs (data, axes), got {len(inputs)}: {list(inputs)}" + ) + raise LayerAnalysisError(model_type=model_type, reason=msg) + + axes_name = inputs[1] + + axes_arr = _resolve_unsqueeze_axes_array( + name=name, + axes_name=axes_name, + graph=graph, + model_type=model_type, + initializer_map=initializer_map, + ) + + _validate_unsqueeze_axes_are_integer( + name=name, + axes_arr=axes_arr, + model_type=model_type, + ) + + out_params = params or {} + out_params["axes"] = _axes_array_to_int_list(axes_arr) + return out_params + + +def _resolve_unsqueeze_axes_array( + *, + name: str, + axes_name: str, + graph: onnx.GraphProto, + model_type: ModelType, + initializer_map: dict[str, onnx.TensorProto] | None = None, +) -> np.ndarray: + if not initializer_map: + initializer_map = {init.name: init for init in graph.initializer} + + if axes_name in initializer_map: + return numpy_helper.to_array(initializer_map[axes_name]) + + const_tensor = _find_constant_tensor_by_output_name( + graph=graph, + output_name=axes_name, + ) + + if const_tensor is not None: + return numpy_helper.to_array(const_tensor) + + msg = ( + f"Unsqueeze '{name}' has dynamic axes input '{axes_name}'. " + "Only constant initializer axes or Constant-node axes are supported." + ) + raise LayerAnalysisError(model_type=model_type, reason=msg) + + +def _find_constant_tensor_by_output_name( + *, + graph: onnx.GraphProto, + output_name: str, +) -> onnx.TensorProto | None: + for n in graph.node: + if n.op_type != "Constant" or not n.output: + continue + if n.output[0] != output_name: + continue + + for attr in n.attribute: + if attr.name == "value" and attr.t is not None: + return attr.t + + # Constant node exists but doesn't have the expected tensor attribute. + return None + + return None + + +def _validate_unsqueeze_axes_are_integer( + *, + name: str, + axes_arr: np.ndarray, + model_type: ModelType, +) -> None: + if not np.issubdtype(axes_arr.dtype, np.integer): + msg = f"Unsqueeze '{name}' axes must be integer, got dtype {axes_arr.dtype}." + raise LayerAnalysisError(model_type=model_type, reason=msg) + + +def _axes_array_to_int_list(axes_arr: np.ndarray) -> list[int]: + if axes_arr.ndim == 0: + return [int(axes_arr)] + return [int(x) for x in axes_arr.reshape(-1).tolist()] diff --git a/python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py b/python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py index 7334183d..d8e9a7bd 100644 --- a/python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +++ b/python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py @@ -31,7 +31,11 @@ from python.core.model_processing.onnx_quantizer.layers.min import MinQuantizer from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer +from python.core.model_processing.onnx_quantizer.layers.squeeze import SqueezeQuantizer from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer +from python.core.model_processing.onnx_quantizer.layers.unsqueeze import ( + UnsqueezeQuantizer, +) class ONNXOpQuantizer: @@ -90,6 +94,8 @@ def __init__(self: ONNXOpQuantizer) -> None: self.register("Max", MaxQuantizer(self.new_initializers)) self.register("Min", MinQuantizer(self.new_initializers)) self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers)) + self.register("Squeeze", SqueezeQuantizer(self.new_initializers)) + self.register("Unsqueeze", UnsqueezeQuantizer(self.new_initializers)) def register( self: ONNXOpQuantizer, diff --git a/python/tests/onnx_quantizer_tests/layers/base.py b/python/tests/onnx_quantizer_tests/layers/base.py index 72daa2f9..12ccbdfc 100644 --- a/python/tests/onnx_quantizer_tests/layers/base.py +++ b/python/tests/onnx_quantizer_tests/layers/base.py @@ -98,7 +98,7 @@ def create_initializers( combined_inits = {**self.required_initializers, **initializer_overrides} for name, data in combined_inits.items(): # Special handling for shape tensors in Reshape, etc. - if name == "shape": + if name in {"shape", "axes"}: tensor = numpy_helper.from_array(data.astype(np.int64), name=name) else: tensor = numpy_helper.from_array(data.astype(np.float32), name=name) diff --git a/python/tests/onnx_quantizer_tests/layers/squeeze_config.py b/python/tests/onnx_quantizer_tests/layers/squeeze_config.py new file mode 100644 index 00000000..ab296e5a --- /dev/null +++ b/python/tests/onnx_quantizer_tests/layers/squeeze_config.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import numpy as np + +from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError +from python.tests.onnx_quantizer_tests.layers.base import ( + e2e_test, + error_test, + valid_test, +) +from python.tests.onnx_quantizer_tests.layers.factory import ( + BaseLayerConfigProvider, + LayerTestConfig, +) + + +class SqueezeConfigProvider(BaseLayerConfigProvider): + """Test configuration provider for Squeeze""" + + @property + def layer_name(self) -> str: + return "Squeeze" + + def get_config(self) -> LayerTestConfig: + # Test opset-newer form: Squeeze(data, axes) where axes is an int64 initializer. + return LayerTestConfig( + op_type="Squeeze", + valid_inputs=["A", "axes"], + valid_attributes={}, # no attribute-based axes + required_initializers={}, + input_shapes={ + "A": [1, 3, 1, 5], + # "axes" is removed from graph inputs automatically + # when it is an initializer. + "axes": [2], + }, + output_shapes={ + "squeeze_output": [3, 5], + }, + ) + + def get_test_specs(self) -> list: + + return [ + # --- VALID TESTS --- + valid_test("axes_omitted") + .description("Squeeze with no axes input: removes all dims of size 1") + .override_inputs("A") # only data input + .override_input_shapes(A=[1, 3, 1, 5]) + .override_output_shapes(squeeze_output=[3, 5]) + .tags("basic", "squeeze", "axes_omitted") + .build(), + valid_test("axes_init_basic") + .description("Squeeze with axes initializer [0,2] on [1,3,1,5] -> [3,5]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([0, 2], dtype=np.int64)) + .override_input_shapes(A=[1, 3, 1, 5]) + .override_output_shapes(squeeze_output=[3, 5]) + .tags("basic", "squeeze", "axes_initializer") + .build(), + valid_test("axes_init_singleton_middle") + .description("Squeeze with axes initializer [1] on [2,1,4] -> [2,4]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([1], dtype=np.int64)) + .override_input_shapes(A=[2, 1, 4]) + .override_output_shapes(squeeze_output=[2, 4]) + .tags("squeeze", "axes_initializer") + .build(), + valid_test("axes_init_negative") + .description("Squeeze with axes initializer [-2] on [2,1,4] -> [2,4]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([-2], dtype=np.int64)) + .override_input_shapes(A=[2, 1, 4]) + .override_output_shapes(squeeze_output=[2, 4]) + .tags("squeeze", "axes_initializer", "negative_axis") + .build(), + # --- ERROR TESTS --- + error_test("duplicate_axes_init") + .description("Duplicate axes in initializer should be rejected") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([1, 1], dtype=np.int64)) + .override_input_shapes(A=[2, 1, 4]) + .override_output_shapes(squeeze_output=[2, 4]) + .expects_error(InvalidParamError, match="axes must not contain duplicates") + .tags("error", "squeeze", "axes_initializer") + .build(), + error_test("dynamic_axes_input_not_supported") + .description( + "Squeeze with runtime axes (2 inputs but axes is NOT an initializer) " + "should be rejected", + ) + .override_inputs("A", "axes") # axes provided as graph input (unsupported) + .override_input_shapes(A=[1, 3, 1, 5], axes=[2]) + .override_output_shapes(squeeze_output=[3, 5]) + .expects_error( + InvalidParamError, + match="Dynamic axes input is not supported", + ) + .tags("error", "squeeze", "axes_input") + .build(), + # --- E2E TESTS --- + e2e_test("e2e_axes_omitted") + .description("End-to-end Squeeze test (axes omitted)") + .override_inputs("A") + .override_input_shapes(A=[1, 3, 1, 5]) + .override_output_shapes(squeeze_output=[3, 5]) + .tags("e2e", "squeeze") + .build(), + e2e_test("e2e_axes_init") + .description("End-to-end Squeeze test (axes initializer)") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([0, 2], dtype=np.int64)) + .override_input_shapes(A=[1, 3, 1, 5]) + .override_output_shapes(squeeze_output=[3, 5]) + .tags("e2e", "squeeze", "axes_initializer") + .build(), + ] diff --git a/python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py b/python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py new file mode 100644 index 00000000..7c6e7766 --- /dev/null +++ b/python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import numpy as np + +from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError +from python.tests.onnx_quantizer_tests.layers.base import ( + e2e_test, + error_test, + valid_test, +) +from python.tests.onnx_quantizer_tests.layers.factory import ( + BaseLayerConfigProvider, + LayerTestConfig, +) + + +class UnsqueezeConfigProvider(BaseLayerConfigProvider): + """Test configuration provider for Unsqueeze""" + + @property + def layer_name(self) -> str: + return "Unsqueeze" + + def get_config(self) -> LayerTestConfig: + # Test opset-newer form: Unsqueeze(data, axes) + # where axes is an int64 initializer. + return LayerTestConfig( + op_type="Unsqueeze", + valid_inputs=["A", "axes"], + valid_attributes={}, # no attribute-based axes + required_initializers={}, + input_shapes={ + "A": [3, 5], + # "axes" will be removed from graph inputs automatically + # when it is an initializer. + "axes": [2], + }, + output_shapes={ + "unsqueeze_output": [1, 3, 1, 5], + }, + ) + + def get_test_specs(self) -> list: + + return [ + # --- VALID TESTS --- + valid_test("axes_init_basic") + .description("Unsqueeze with axes initializer [0,2] on [3,5] -> [1,3,1,5]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([0, 2], dtype=np.int64)) + .override_input_shapes(A=[3, 5]) + .override_output_shapes(unsqueeze_output=[1, 3, 1, 5]) + .tags("basic", "unsqueeze", "axes_initializer") + .build(), + valid_test("axes_init_single_axis") + .description("Unsqueeze with axes initializer [1] on [3,5] -> [3,1,5]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([1], dtype=np.int64)) + .override_input_shapes(A=[3, 5]) + .override_output_shapes(unsqueeze_output=[3, 1, 5]) + .tags("unsqueeze", "axes_initializer") + .build(), + valid_test("axes_init_negative") + .description("Unsqueeze with negative axis [-1] on [3,5] -> [3,5,1]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([-1], dtype=np.int64)) + .override_input_shapes(A=[3, 5]) + .override_output_shapes(unsqueeze_output=[3, 5, 1]) + .tags("unsqueeze", "axes_initializer", "negative_axis") + .build(), + valid_test("axes_init_two_axes_append") + .description("Unsqueeze with axes [2,3] on [3,5] -> [3,5,1,1]") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([2, 3], dtype=np.int64)) + .override_input_shapes(A=[3, 5]) + .override_output_shapes(unsqueeze_output=[3, 5, 1, 1]) + .tags("unsqueeze", "axes_initializer") + .build(), + # --- ERROR TESTS --- + error_test("duplicate_axes_init") + .description("Duplicate axes in initializer should be rejected") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([1, 1], dtype=np.int64)) + .override_input_shapes(A=[3, 5]) + .override_output_shapes( + unsqueeze_output=[3, 1, 5], + ) # not used; kept consistent + .expects_error(InvalidParamError, match="axes must not contain duplicates") + .tags("error", "unsqueeze", "axes_initializer") + .build(), + error_test("dynamic_axes_input_not_supported") + .description( + "Unsqueeze with runtime axes (2 inputs but axes is NOT an initializer) " + "should be rejected", + ) + .override_inputs("A", "axes") # axes provided as graph input (unsupported) + .override_input_shapes(A=[3, 5], axes=[2]) + .override_output_shapes(unsqueeze_output=[1, 3, 1, 5]) + .expects_error( + InvalidParamError, + match="Dynamic axes input is not supported", + ) + .tags("error", "unsqueeze", "axes_input") + .build(), + # --- E2E TESTS --- + e2e_test("e2e_axes_init") + .description("End-to-end Unsqueeze test (axes initializer)") + .override_inputs("A", "axes") + .override_initializer("axes", np.array([0, 2], dtype=np.int64)) + .override_input_shapes(A=[3, 5]) + .override_output_shapes(unsqueeze_output=[1, 3, 1, 5]) + .tags("e2e", "unsqueeze", "axes_initializer") + .build(), + ] diff --git a/rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs b/rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs index 8f0c8165..4ba5b828 100644 --- a/rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs +++ b/rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs @@ -19,6 +19,8 @@ use crate::circuit_functions::layers::maxpool::MaxPoolLayer; use crate::circuit_functions::layers::min::MinLayer; use crate::circuit_functions::layers::relu::ReluLayer; use crate::circuit_functions::layers::reshape::ReshapeLayer; +use crate::circuit_functions::layers::squeeze::SqueezeLayer; +use crate::circuit_functions::layers::unsqueeze::UnsqueezeLayer; use expander_compiler::frontend::{Config, RootAPI}; use std::str::FromStr; @@ -144,4 +146,6 @@ define_layers! { Min => { name: "Min", builder: MinLayer::build }, ReLU => { name: "ReLU", builder: ReluLayer::build, aliases: ["Relu"] }, Reshape => { name: "Reshape", builder: ReshapeLayer::build }, + Squeeze => { name: "Squeeze", builder: SqueezeLayer::build }, + Unsqueeze => { name: "Unsqueeze", builder: UnsqueezeLayer::build }, } diff --git a/rust/jstprove_circuits/src/circuit_functions/layers/mod.rs b/rust/jstprove_circuits/src/circuit_functions/layers/mod.rs index d991cb54..20148b64 100644 --- a/rust/jstprove_circuits/src/circuit_functions/layers/mod.rs +++ b/rust/jstprove_circuits/src/circuit_functions/layers/mod.rs @@ -14,7 +14,9 @@ pub mod min; pub mod mul; pub mod relu; pub mod reshape; +pub mod squeeze; pub mod sub; +pub mod unsqueeze; pub use errors::LayerError; pub use layer_kinds::LayerKind; diff --git a/rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs b/rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs new file mode 100644 index 00000000..4525a530 --- /dev/null +++ b/rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs @@ -0,0 +1,181 @@ +use std::collections::{HashMap, HashSet}; + +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use ndarray::{ArrayD, IxDyn}; + +use crate::circuit_functions::{ + CircuitError, + layers::{LayerError, LayerKind, layer_ops::LayerOp}, + utils::{ + constants::{AXES, INPUT}, + onnx_model::{extract_params_and_expected_shape, get_input_name, get_param}, + }, +}; + +#[allow(dead_code)] +#[derive(Debug)] +pub struct SqueezeLayer { + name: String, + axes: Option>, + input_shape: Vec, + inputs: Vec, + outputs: Vec, +} + +impl SqueezeLayer { + fn normalize_axes(&self, axes: &[i64], rank: usize) -> Result, CircuitError> { + let rank_i64 = i64::try_from(rank).map_err(|_| LayerError::InvalidParameterValue { + layer: LayerKind::Squeeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("rank {rank} cannot be represented as i64"), + })?; + + let mut out: Vec = Vec::with_capacity(axes.len()); + let mut seen: HashSet = HashSet::new(); + + for &a in axes { + let ax_i64 = if a < 0 { a + rank_i64 } else { a }; + + if ax_i64 < 0 || ax_i64 >= rank_i64 { + return Err(LayerError::InvalidParameterValue { + layer: LayerKind::Squeeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("axis {a} out of range for rank {rank}"), + } + .into()); + } + + let ax = usize::try_from(ax_i64).map_err(|_| LayerError::InvalidParameterValue { + layer: LayerKind::Squeeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("axis {a} is not a valid usize index after normalization"), + })?; + + if !seen.insert(ax) { + return Err(LayerError::InvalidParameterValue { + layer: LayerKind::Squeeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("duplicate axis {ax} in axes={axes:?}"), + } + .into()); + } + + out.push(ax); + } + + out.sort_unstable(); + Ok(out) + } + + fn squeezed_shape( + &self, + input_shape: &[usize], + axes: Option<&Vec>, + ) -> Result, CircuitError> { + let rank = input_shape.len(); + + // axes omitted => remove all dims of size 1 + if axes.is_none() { + let shape: Vec = input_shape.iter().copied().filter(|&d| d != 1).collect(); + // ONNX allows 0-D output if everything is squeezed + // ndarray supports IxDyn(&[]) + return Ok(shape); + } + + let axes_ref = axes.expect("axes is Some here"); + let axes_u = self.normalize_axes(axes_ref.as_slice(), rank)?; + + let axes_set: HashSet = axes_u.iter().copied().collect(); + + // Validate specified axes are actually squeezable + for &ax in &axes_u { + let dim = input_shape[ax]; + if dim != 1 { + return Err(LayerError::InvalidShape { + layer: LayerKind::Squeeze, + msg: format!( + "cannot squeeze axis {ax}: expected dim==1, got {dim} (shape={input_shape:?})" + ), + } + .into()); + } + } + + let out_shape: Vec = input_shape + .iter() + .enumerate() + .filter_map(|(i, &d)| if axes_set.contains(&i) { None } else { Some(d) }) + .collect(); + + Ok(out_shape) + } +} + +impl> LayerOp for SqueezeLayer { + fn apply( + &self, + _api: &mut Builder, + input: HashMap>, + ) -> Result<(Vec, ArrayD), CircuitError> { + let input_name = get_input_name(&self.inputs, 0, LayerKind::Squeeze, INPUT)?; + let layer_input = input + .get(&input_name.clone()) + .ok_or_else(|| LayerError::MissingInput { + layer: LayerKind::Squeeze, + name: input_name.clone(), + })? + .clone(); + + let in_shape: Vec = layer_input.shape().to_vec(); + let out_shape = self.squeezed_shape(&in_shape, self.axes.as_ref())?; + + // Reshape without changing element order. + let flat: Vec = layer_input.iter().copied().collect(); + + let out = ArrayD::from_shape_vec(IxDyn(&out_shape), flat).map_err(|e| { + LayerError::InvalidShape { + layer: LayerKind::Squeeze, + msg: format!("failed to reshape in Squeeze: {e} (out_shape={out_shape:?})"), + } + })?; + + Ok((self.outputs.clone(), out)) + } + + fn build( + layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer, + _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams, + _optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry, + _is_rescale: bool, + _index: usize, + layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext, + ) -> Result>, CircuitError> { + let (params, expected_shape) = extract_params_and_expected_shape(layer_context, layer) + .map_err(|e| LayerError::Other { + layer: LayerKind::Squeeze, + msg: format!("extract_params_and_expected_shape failed: {e}"), + })?; + + // axes may be missing (axes omitted semantics) + // When present, parse_attributes on Python side should serialize it as a list. + // `get_param` will error if missing, so we only call it if the key exists. + let axes: Option> = match params.get(AXES) { + Some(_) => Some(get_param(&layer.name, AXES, ¶ms)?), + None => None, + }; + + let squeeze = Self { + name: layer.name.clone(), + axes, + input_shape: expected_shape.clone(), + inputs: layer.inputs.clone(), + outputs: layer.outputs.clone(), + }; + + Ok(Box::new(squeeze)) + } +} diff --git a/rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs b/rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs new file mode 100644 index 00000000..252ed4b4 --- /dev/null +++ b/rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs @@ -0,0 +1,174 @@ +use std::collections::{HashMap, HashSet}; + +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use ndarray::{ArrayD, IxDyn}; + +use crate::circuit_functions::{ + CircuitError, + layers::{LayerError, LayerKind, layer_ops::LayerOp}, + utils::{ + constants::{AXES, INPUT}, + onnx_model::{extract_params_and_expected_shape, get_input_name, get_param}, + }, +}; + +#[allow(dead_code)] +#[derive(Debug)] +pub struct UnsqueezeLayer { + name: String, + axes: Vec, // Unsqueeze requires axes + input_shape: Vec, + inputs: Vec, + outputs: Vec, +} + +impl UnsqueezeLayer { + fn normalize_axes(&self, rank_out: usize) -> Result, CircuitError> { + let mut out: Vec = Vec::with_capacity(self.axes.len()); + let mut seen: HashSet = HashSet::new(); + + for &a in &self.axes { + let rank_i64 = + i64::try_from(rank_out).map_err(|_| LayerError::InvalidParameterValue { + layer: LayerKind::Unsqueeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("rank {rank_out} cannot be represented as i64"), + })?; + + let ax_i64 = if a < 0 { a + rank_i64 } else { a }; + + if ax_i64 < 0 || ax_i64 >= rank_i64 { + return Err(LayerError::InvalidParameterValue { + layer: LayerKind::Unsqueeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("axis {a} out of range for rank {rank_out}"), + } + .into()); + } + + let ax = usize::try_from(ax_i64).map_err(|_| LayerError::InvalidParameterValue { + layer: LayerKind::Unsqueeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("axis {a} is not a valid usize index after normalization"), + })?; + + if !seen.insert(ax) { + return Err(LayerError::InvalidParameterValue { + layer: LayerKind::Unsqueeze, + layer_name: self.name.clone(), + param_name: AXES.into(), + value: format!("duplicate axis {ax} in axes={:?}", self.axes), + } + .into()); + } + + out.push(ax); + } + + out.sort_unstable(); + Ok(out) + } + + fn unsqueezed_shape(&self, input_shape: &[usize]) -> Result, CircuitError> { + let rank_in = input_shape.len(); + let rank_out = rank_in + self.axes.len(); + + let axes_u = self.normalize_axes(rank_out)?; + let axes_set: HashSet = axes_u.iter().copied().collect(); + + let mut out_shape: Vec = Vec::with_capacity(rank_out); + let mut j: usize = 0; + + for i in 0..rank_out { + if axes_set.contains(&i) { + out_shape.push(1); + } else { + let dim = *input_shape.get(j).ok_or_else(|| LayerError::InvalidShape { + layer: LayerKind::Unsqueeze, + msg: format!( + "ran out of input dims when building output shape (input_shape={input_shape:?}, axes={:?})", + self.axes + ), + })?; + out_shape.push(dim); + j += 1; + } + } + + if j != rank_in { + return Err(LayerError::InvalidShape { + layer: LayerKind::Unsqueeze, + msg: format!( + "did not consume all input dims (consumed={j}, rank_in={rank_in}); input_shape={input_shape:?}, axes={:?}", + self.axes + ), + } + .into()); + } + + Ok(out_shape) + } +} + +impl> LayerOp for UnsqueezeLayer { + fn apply( + &self, + _api: &mut Builder, + input: HashMap>, + ) -> Result<(Vec, ArrayD), CircuitError> { + let input_name = get_input_name(&self.inputs, 0, LayerKind::Unsqueeze, INPUT)?; + let layer_input = input + .get(input_name) + .ok_or_else(|| LayerError::MissingInput { + layer: LayerKind::Unsqueeze, + name: input_name.clone(), + })? + .clone(); + + let in_shape: Vec = layer_input.shape().to_vec(); + let out_shape = self.unsqueezed_shape(&in_shape)?; + + // Reshape without changing element order. + let flat: Vec = layer_input.iter().copied().collect(); + + let out = ArrayD::from_shape_vec(IxDyn(&out_shape), flat).map_err(|e| { + LayerError::InvalidShape { + layer: LayerKind::Unsqueeze, + msg: format!("failed to reshape in Unsqueeze: {e} (out_shape={out_shape:?})"), + } + })?; + + Ok((self.outputs.clone(), out)) + } + + fn build( + layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer, + _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams, + _optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry, + _is_rescale: bool, + _index: usize, + layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext, + ) -> Result>, CircuitError> { + let (params, expected_shape) = extract_params_and_expected_shape(layer_context, layer) + .map_err(|e| LayerError::Other { + layer: LayerKind::Unsqueeze, + msg: format!("extract_params_and_expected_shape failed: {e}"), + })?; + + // Unsqueeze requires axes. + let axes: Vec = get_param(&layer.name, AXES, ¶ms)?; + + let unsqueeze = Self { + name: layer.name.clone(), + axes, + input_shape: expected_shape.clone(), + inputs: layer.inputs.clone(), + outputs: layer.outputs.clone(), + }; + + Ok(Box::new(unsqueeze)) + } +} diff --git a/rust/jstprove_circuits/src/circuit_functions/utils/constants.rs b/rust/jstprove_circuits/src/circuit_functions/utils/constants.rs index d4cebc63..61604045 100644 --- a/rust/jstprove_circuits/src/circuit_functions/utils/constants.rs +++ b/rust/jstprove_circuits/src/circuit_functions/utils/constants.rs @@ -12,6 +12,9 @@ pub const GEMM: &str = "Gemm"; /// Value for Constant layer pub const VALUE: &str = "value"; +/// AXES for squeezing/unsqueezing/reductions, etc. +pub const AXES: &str = "axes"; + /// AXIS for reshaping pub const AXIS: &str = "axis";