Merged
Changes from all commits (31 commits)
77d3c55  fresh freivalds based on new main (tmfreiberg, Dec 12, 2025)
008a0e6  moved and renamed math.rs (tmfreiberg, Dec 12, 2025)
4cf64e8  revise docstrings (tmfreiberg, Dec 12, 2025)
56dbbbe  pedantic clippy docstrings and apply too long (tmfreiberg, Dec 12, 2025)
63310c5  unused warnings and layer_kind moved/clone (tmfreiberg, Dec 12, 2025)
3439bd3  Code Rabbit comments addressed (tmfreiberg, Dec 15, 2025)
2aff6c4  squeeze python side started (tmfreiberg, Dec 16, 2025)
743799d  Add Squeeze quantizer and opset-22 compatible tests (tmfreiberg, Dec 17, 2025)
aab328d  squeeze test errors addressed (tmfreiberg, Dec 17, 2025)
56e9970  squeeze rust side (tmfreiberg, Dec 17, 2025)
8704c34  address errors (tmfreiberg, Dec 17, 2025)
f5557d1  unsqueeze python side started (tmfreiberg, Dec 18, 2025)
e8f82c4  unsqueeze unit integration errors (tmfreiberg, Dec 18, 2025)
4a4656f  unsqueeze unit integration errors 2 (tmfreiberg, Dec 19, 2025)
c984cb2  f string error (tmfreiberg, Dec 19, 2025)
aa1e0be  name error (tmfreiberg, Dec 19, 2025)
103135c  unsqueeze rust side (tmfreiberg, Dec 19, 2025)
be4e0b8  unsqueeze quantize rewrite (tmfreiberg, Dec 19, 2025)
b5f8c5a  whoops (tmfreiberg, Dec 19, 2025)
d667508  added special case for unsqueeze onnx converter (tmfreiberg, Dec 19, 2025)
5e38482  linter (tmfreiberg, Dec 19, 2025)
f6f53e0  removed unnecessary clone per Code Rabbit (tmfreiberg, Dec 19, 2025)
842ddd5  linting (tmfreiberg, Dec 19, 2025)
30e86a9  ruff (tmfreiberg, Dec 19, 2025)
2a12a5b  Merge remote-tracking branch 'origin/main' into squeeze (tmfreiberg, Dec 19, 2025)
c89f684  Unsqueeze: move axes extraction into pre-analysis transform; restore … (Jan 23, 2026)
506729d  Unsqueeze: support Constant-node axes via pre-analysis transform (Jan 23, 2026)
5e5242b  Unsqueeze: support Constant-node axes via pre-analysis transform (Jan 23, 2026)
0aac17a  Move unsqueeze pre_analysis_transform into file (jsgold-1, Jan 29, 2026)
d0d69a2  Fix typing (jsgold-1, Jan 29, 2026)
7205fbd  Merge branch 'main' into squeeze (Jan 29, 2026)
6 changes: 4 additions & 2 deletions python/core/model_processing/converters/onnx_converter.py
@@ -1022,7 +1022,9 @@ def _extract_model_io_info(
]
self.input_shape = get_input_shapes(onnx_model)

def get_weights(self: ONNXConverter) -> tuple[
def get_weights(
self: ONNXConverter,
) -> tuple[
dict[str, list[ONNXLayerDict]],
dict[str, list[ONNXLayerDict]],
CircuitParamsDict,
@@ -1055,7 +1057,7 @@ def get_weights(self: ONNXConverter) -> tuple[
scale_base=scale_base,
)
# Get layers in correct format
(architecture, w_and_b) = self.analyze_layers(
architecture, w_and_b = self.analyze_layers(
scaled_and_transformed_model,
output_name_to_shape,
)
155 changes: 155 additions & 0 deletions python/core/model_processing/onnx_quantizer/layers/squeeze.py
@@ -0,0 +1,155 @@
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

import numpy as np
from onnx import numpy_helper

from python.core.model_processing.onnx_quantizer.exceptions import (
InvalidParamError,
)
from python.core.model_processing.onnx_quantizer.layers.base import (
BaseOpQuantizer,
QuantizerBase,
ScaleConfig,
)

if TYPE_CHECKING:
import onnx


class QuantizeSqueeze(QuantizerBase):
OP_TYPE = "Squeeze"
DOMAIN = ""
USE_WB = False
USE_SCALING = False
# Only the data input is relevant for scale-planning.
SCALE_PLAN: ClassVar = {0: 1}


class SqueezeQuantizer(BaseOpQuantizer, QuantizeSqueeze):
"""
Quantizer for ONNX Squeeze.

Squeeze is scale-preserving (pure shape/view transform):
- No arithmetic
- No rescaling
- No custom op

We support:
- axes as an attribute (older opsets)
- axes as a constant initializer input (newer opsets)

We do NOT support dynamic axes provided at runtime.
"""

def __init__(
self: SqueezeQuantizer,
new_initializers: list[onnx.TensorProto] | None = None,
) -> None:
super().__init__()
if new_initializers is not None:
self.new_initializers = new_initializers

def quantize(
self: SqueezeQuantizer,
node: onnx.NodeProto,
graph: onnx.GraphProto,
scale_config: ScaleConfig,
initializer_map: dict[str, onnx.TensorProto],
) -> list[onnx.NodeProto]:
# Pure passthrough; QuantizerBase handles standard bookkeeping.
return QuantizeSqueeze.quantize(
self,
node,
graph,
scale_config,
initializer_map,
)

_N_INPUTS_NO_AXES: ClassVar[int] = 1
_N_INPUTS_WITH_AXES: ClassVar[int] = 2

def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None:
for attr in node.attribute:
if attr.name == "axes":
return list(attr.ints)
return None

def _get_axes_from_initializer_input(
self,
node: onnx.NodeProto,
initializer_map: dict[str, onnx.TensorProto],
) -> list[int]:
axes_name = node.input[1]
if axes_name not in initializer_map:
raise InvalidParamError(
node_name=node.name,
op_type=node.op_type,
message=(
f"Dynamic axes input is not supported for Squeeze "
f"(expected axes '{axes_name}' to be an initializer)."
),
)

axes_tensor = initializer_map[axes_name]
arr = numpy_helper.to_array(axes_tensor)

if not np.issubdtype(arr.dtype, np.integer):
raise InvalidParamError(
node_name=node.name,
op_type=node.op_type,
message=f"Squeeze axes initializer must be integer, got {arr.dtype}.",
attr_key="axes",
expected="integer tensor (0-D or 1-D)",
)

if arr.ndim == 0:
return [int(arr)]
if arr.ndim == 1:
return [int(x) for x in arr.tolist()]

raise InvalidParamError(
node_name=node.name,
op_type=node.op_type,
message=f"Squeeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.",
attr_key="axes",
expected="0-D scalar or 1-D list of axes",
)

def check_supported(
self: SqueezeQuantizer,
node: onnx.NodeProto,
initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> None:
self.validate_node_has_output(node)
initializer_map = initializer_map or {}

n_inputs = len(node.input)
if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES):
raise InvalidParamError(
node_name=node.name,
op_type=node.op_type,
message=f"Squeeze expects 1 or 2 inputs, got {n_inputs}.",
)

axes = self._get_axes_from_attribute(node)

# If axes is provided as a second input, it must be a constant initializer.
if axes is None and n_inputs == self._N_INPUTS_WITH_AXES:
axes = self._get_axes_from_initializer_input(node, initializer_map)

# If axes is omitted entirely, ONNX semantics are "remove all dims of size 1".
# We can't validate legality here without rank/shape; defer to Rust.

if axes is None:
return

if len(set(axes)) != len(axes):
raise InvalidParamError(
node_name=node.name,
op_type=node.op_type,
message=f"axes must not contain duplicates: {axes}",
attr_key="axes",
expected="axes list with unique entries",
)
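
For reference, a minimal sketch (not part of the diff) of how the new SqueezeQuantizer.check_supported validation above handles the two supported axes forms. The import path is assumed from the file location shown in this PR, and the tensor names "x", "y", and "axes" are placeholders.

# Minimal sketch, assuming the repository root is importable as a package.
import numpy as np
from onnx import helper, numpy_helper

from python.core.model_processing.onnx_quantizer.layers.squeeze import SqueezeQuantizer

quantizer = SqueezeQuantizer()

# Older-opset style: axes supplied as a node attribute.
node_attr = helper.make_node(
    "Squeeze", inputs=["x"], outputs=["y"], name="sq_attr", axes=[0]
)
quantizer.check_supported(node_attr, initializer_map={})  # passes

# Newer-opset style: axes supplied as a constant initializer input.
axes_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="axes")
node_init = helper.make_node(
    "Squeeze", inputs=["x", "axes"], outputs=["y"], name="sq_init"
)
quantizer.check_supported(node_init, initializer_map={"axes": axes_init})  # passes

# A second input that is not an initializer (runtime axes) or duplicate axes
# would raise InvalidParamError, per the checks above.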