@tmfreiberg tmfreiberg commented Dec 19, 2025

Description

This PR adds full support for the ONNX Squeeze and Unsqueeze operators across the quantization, model-conversion, and circuit-building pipeline.

Both operators are treated as pure shape-transform / passthrough layers:

  • no arithmetic
  • no rescaling
  • no quantization
  • no additional circuit constraints
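The passthrough behavior can be illustrated with a small NumPy sketch (NumPy stands in here for the quantizer/circuit pipeline; the point is that only shape metadata changes, never values or element order):

```python
import numpy as np

x = np.arange(6, dtype=np.int64).reshape(1, 2, 3, 1)

# Squeeze: drop the size-1 dims at axes 0 and 3; data is untouched.
squeezed = np.squeeze(x, axis=(0, 3))

# Unsqueeze: insert size-1 dims back at axes 0 and 3.
unsqueezed = np.expand_dims(squeezed, axis=(0, 3))

assert squeezed.shape == (2, 3)
assert unsqueezed.shape == (1, 2, 3, 1)
# Values and flattened element order are identical throughout.
assert np.array_equal(x.ravel(), squeezed.ravel())
assert np.array_equal(x, unsqueezed)
```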

Key changes

  1. Quantizer support (Squeeze & Unsqueeze)

    • Implemented SqueezeQuantizer and UnsqueezeQuantizer as strict passthroughs.
    • The quantize step delegates to the base quantizer logic and does not modify tensor values.
    • This matches ONNX semantics: these ops only reshape metadata.
  2. Circuit layer wiring (Rust)

    • Added corresponding layer handlers in the circuit layer registry.
    • The circuit builder forwards inputs directly to outputs with no gates added.
  3. ONNXConverter: explicit Unsqueeze axis handling
    Additional validation logic was added to python/core/model_processing/converters/onnx_converter.py to correctly handle opset ≥ 13 Unsqueeze semantics.

    In opset ≥ 13:

    Unsqueeze(data, axes)
    

    where axes is provided as a second input, not an attribute.

    The converter now:

    • Detects Unsqueeze nodes where axes is not present in attributes

    • Enforces that:

      • exactly two inputs are provided
      • the axes input is a constant initializer
    • Extracts the axes values from the initializer and injects them into
      layer.params["axes"], so downstream circuit construction receives a
      uniform representation.

    This mirrors the existing handling of Squeeze, and ensures:

    • no runtime/dynamic axes are allowed (unsupported by the circuit backend)
    • consistent architecture serialization for both ops
    • clearer error messages when malformed ONNX graphs are encountered
  4. End-to-end test coverage

    • Added E2E tests covering:

      • Squeeze
      • Unsqueeze with axes provided via initializer
    • Verified full flow: quantize → compile → witness → prove → verify
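The Unsqueeze axis-handling flow from item 3 can be sketched in plain Python. The function name, the constant, and the dict-based initializer map below are hypothetical stand-ins for illustration, not the actual onnx_converter.py API:

```python
import numpy as np

# Hypothetical stand-ins: `node_inputs` mimics an opset >= 13 Unsqueeze
# node's input list, `initializers` the graph's constant-initializer map.
N_UNSQUEEZE_INPUTS = 2

def extract_unsqueeze_axes(node_inputs, initializers):
    if len(node_inputs) != N_UNSQUEEZE_INPUTS:
        raise ValueError(
            f"Unsqueeze expects {N_UNSQUEEZE_INPUTS} inputs, got {len(node_inputs)}"
        )
    axes_name = node_inputs[1]
    if axes_name not in initializers:
        # Runtime/dynamic axes are unsupported by the circuit backend.
        raise ValueError(f"dynamic axes input '{axes_name}' is unsupported")
    axes = np.asarray(initializers[axes_name])
    if not np.issubdtype(axes.dtype, np.integer):
        raise ValueError("axes initializer must be an integer tensor")
    # Accept both 0-D scalar and 1-D vector forms.
    return [int(axes)] if axes.ndim == 0 else [int(a) for a in axes]

# Axes land in a uniform params["axes"] representation for downstream code.
params = {"axes": extract_unsqueeze_axes(
    ["data", "axes"], {"axes": np.array([0, 2], dtype=np.int64)}
)}
```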


Related Issue

Type of Change

  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix/feature causing existing functionality to break)
  • Refactor (non-functional changes)
  • Documentation update

Checklist

  • Code follows project patterns
  • Tests added/updated (if applicable)
  • Documentation updated (if applicable)
  • Self-review of code
  • All tests pass locally
  • Linter passes locally

Deployment Notes

Additional Comments

  • Squeeze and Unsqueeze are intentionally treated as no-op arithmetic layers; all logic is confined to shape metadata handling.
  • Runtime-variable axes for Unsqueeze are explicitly rejected to avoid unsound circuit construction.
  • The implementation follows the existing design philosophy used for other shape-only ops.

Summary by CodeRabbit

  • New Features

    • End-to-end support for ONNX Squeeze and Unsqueeze, accepting axes via attribute or constant initializer and exposed in circuit layers.
  • Bug Fixes / Validation

    • Stronger validation and clearer errors for axes handling (rejects unsupported dynamic axes, non-integer/invalid dimensions, and duplicate axes).
  • Tests

    • Added comprehensive unit and end-to-end tests covering valid and error scenarios for both ops.


coderabbitai bot commented Dec 19, 2025

Walkthrough

Adds ONNX Squeeze and Unsqueeze quantizers and tests, registers them in the ONNX op quantizer, implements the corresponding Rust circuit layers and an AXES constant, updates test initializer handling, and applies a minor formatting change to the ONNX converter's get_weights signature.

Changes

  • ONNX quantizers (python/core/model_processing/onnx_quantizer/layers/squeeze.py, python/core/model_processing/onnx_quantizer/layers/unsqueeze.py): new SqueezeQuantizer and UnsqueezeQuantizer extract axes from attributes or constant-initializer inputs, validate dtype/dimensionality/uniqueness, raise InvalidParamError on invalid configs, and implement passthrough quantize behavior (Unsqueeze also embeds axes during pre-analysis).
  • Quantizer registration (python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py): registers "Squeeze" and "Unsqueeze" with their quantizer instances in ONNXOpQuantizer.__init__.
  • ONNX converter, minor (python/core/model_processing/converters/onnx_converter.py): formatting-only change to the get_weights signature/unpacking; no semantic behavior change.
  • Python tests (python/tests/onnx_quantizer_tests/layers/base.py, python/tests/onnx_quantizer_tests/layers/squeeze_config.py, python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py): treats the "axes" initializer as int64 in the test helper; adds SqueezeConfigProvider and UnsqueezeConfigProvider with valid, error, and e2e specs covering attribute/initializer axes, duplicates, and dynamic-axis errors.
  • Rust layer registry & modules (rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs, rust/jstprove_circuits/src/circuit_functions/layers/mod.rs): adds Squeeze and Unsqueeze variants to LayerKind, exposes pub mod squeeze and pub mod unsqueeze, and removes the internal math module.
  • Rust Squeeze implementation (rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs): new SqueezeLayer with axes normalization, validation, squeezed-shape computation, LayerOp build/apply logic, and error handling.
  • Rust Unsqueeze implementation (rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs): new UnsqueezeLayer with axes normalization, validation, unsqueezed-shape computation, LayerOp build/apply logic, and error handling.
  • Rust constants (rust/jstprove_circuits/src/circuit_functions/utils/constants.rs): adds a public AXES constant ("axes").

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • Batchnorm layer #83 — Overlaps with formatting/weight extraction changes in python/core/model_processing/converters/onnx_converter.py.
  • Single layer tests #74 — Related additions for ONNX Squeeze/Unsqueeze quantizers and registration in the ONNX op quantizer.

Suggested reviewers

  • tmfreiberg

Poem

🐇
I nibble axes, tuck and play,
I squeeze and stretch the dims all day.
Python checks and Rust reshape,
Tensors snug in every shape.
Hop, hop, hooray for layers new!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title accurately summarizes the main changes: adding Squeeze and Unsqueeze operator support across quantizer and circuit builder, which aligns with all major file changes.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, above the required threshold of 80.00%.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1)

104-133: Minor: Redundant clone on line 112.

input_name is already a &String, so .clone() before .get() is unnecessary since HashMap::get accepts &Q where String: Borrow<Q>.

The core logic is correct: flattening and reshaping preserves element order, which is the correct behavior for Unsqueeze.

🔎 Proposed fix
         let input_name = get_input_name(&self.inputs, 0, LayerKind::Unsqueeze, INPUT)?;
         let layer_input = input
-            .get(&input_name.clone())
+            .get(input_name)
             .ok_or_else(|| LayerError::MissingInput {
                 layer: LayerKind::Unsqueeze,
                 name: input_name.clone(),
             })?
             .clone();
rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (2)

283-310: Consider reducing redundant clones in Freivalds path.

The Freivalds branch clones both input_array and weights_array twice: once for unconstrained_matrix_multiplication and again for freivalds_verify_matrix_product. Since freivalds_verify_matrix_product takes references, you could avoid the second clone by passing references to the already-converted dynamic arrays.

🔎 Proposed optimization to reduce clones
     if use_freivalds {
+        let input_dyn = input_array.clone().into_dyn();
+        let weights_dyn = weights_array.clone().into_dyn();
+
         let core_dyn = unconstrained_matrix_multiplication(
             api,
-            input_array.clone().into_dyn(),
-            weights_array.clone().into_dyn(),
+            input_dyn.clone(),
+            weights_dyn.clone(),
             layer_kind.clone(),
         )?;
 
         freivalds_verify_matrix_product(
             api,
-            &input_array.clone().into_dyn(),
-            &weights_array.clone().into_dyn(),
+            &input_dyn,
+            &weights_dyn,
             &core_dyn,
             layer_kind,
             freivalds_reps,
         )?;

357-396: Cost model looks reasonable; consider documenting the derivation.

The cost model for deciding Freivalds vs. full matmul is sensible. A few observations:

  1. The d == 0 guard (line 388-390) correctly handles degenerate cases where Freivalds would have no meaningful cost advantage.
  2. Using saturating_* arithmetic is the right approach to avoid panics on large dimensions.

Consider adding a brief inline comment or doc reference explaining where the cost formulas come from (e.g., counting mul + add operations in each approach) to help future maintainers verify or tune the heuristic.

📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fe3737c and d667508.

📒 Files selected for processing (21)
  • python/core/model_processing/converters/onnx_converter.py (2 hunks)
  • python/core/model_processing/onnx_quantizer/layers/squeeze.py (1 hunks)
  • python/core/model_processing/onnx_quantizer/layers/unsqueeze.py (1 hunks)
  • python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2 hunks)
  • python/tests/onnx_quantizer_tests/layers/base.py (1 hunks)
  • python/tests/onnx_quantizer_tests/layers/max_config.py (0 hunks)
  • python/tests/onnx_quantizer_tests/layers/squeeze_config.py (1 hunks)
  • python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/gadgets/mod.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/add.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/batchnorm.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (11 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/math.rs (0 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/mod.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/mul.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/sub.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (1 hunks)
💤 Files with no reviewable changes (2)
  • python/tests/onnx_quantizer_tests/layers/max_config.py
  • rust/jstprove_circuits/src/circuit_functions/layers/math.rs
🧰 Additional context used
🧬 Code graph analysis (12)
rust/jstprove_circuits/src/circuit_functions/layers/mul.rs (1)
rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1)
  • matrix_hadamard_product (143-157)
rust/jstprove_circuits/src/circuit_functions/layers/sub.rs (1)
rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1)
  • matrix_subtraction (183-197)
python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (3)
python/core/model_processing/onnx_quantizer/exceptions.py (1)
  • InvalidParamError (27-67)
python/tests/onnx_quantizer_tests/layers/base.py (12)
  • e2e_test (247-248)
  • error_test (239-240)
  • valid_test (235-236)
  • BaseLayerConfigProvider (251-277)
  • LayerTestConfig (58-166)
  • description (175-177)
  • override_inputs (191-193)
  • override_initializer (187-189)
  • override_input_shapes (195-197)
  • override_output_shapes (199-201)
  • tags (215-217)
  • expects_error (203-213)
python/tests/onnx_quantizer_tests/layers/squeeze_config.py (3)
  • layer_name (21-22)
  • get_config (24-39)
  • get_test_specs (41-114)
rust/jstprove_circuits/src/circuit_functions/layers/batchnorm.rs (1)
rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (2)
  • matrix_addition (103-117)
  • matrix_hadamard_product (143-157)
rust/jstprove_circuits/src/circuit_functions/layers/add.rs (1)
rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1)
  • matrix_addition (103-117)
rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (2)
rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (2)
  • freivalds_verify_matrix_product (449-563)
  • matrix_multiplication (294-338)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (2)
  • get_w_or_b (69-113)
  • get_param_or_default (283-305)
rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1)
  • build (135-161)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (1)
  • build (140-171)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (4)
python/core/circuits/errors.py (1)
  • CircuitError (6-32)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
  • extract_params_and_expected_shape (176-208)
  • get_input_name (255-267)
  • get_param (223-240)
rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)
  • input (105-105)
python/core/model_processing/converters/onnx_converter.py (1)
  • ONNXLayer (69-108)
python/tests/onnx_quantizer_tests/layers/squeeze_config.py (3)
python/core/model_processing/onnx_quantizer/exceptions.py (1)
  • InvalidParamError (27-67)
python/tests/onnx_quantizer_tests/layers/base.py (15)
  • e2e_test (247-248)
  • error_test (239-240)
  • valid_test (235-236)
  • BaseLayerConfigProvider (251-277)
  • LayerTestConfig (58-166)
  • layer_name (260-261)
  • get_config (255-256)
  • get_test_specs (263-265)
  • description (175-177)
  • override_inputs (191-193)
  • override_input_shapes (195-197)
  • override_output_shapes (199-201)
  • tags (215-217)
  • build (223-231)
  • override_initializer (187-189)
python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (3)
  • layer_name (21-22)
  • get_config (24-41)
  • get_test_specs (43-114)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (2)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
  • extract_params_and_expected_shape (176-208)
  • get_input_name (255-267)
  • get_param (223-240)
rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)
  • input (105-105)
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2)
python/core/model_processing/onnx_quantizer/layers/squeeze.py (1)
  • SqueezeQuantizer (30-154)
python/core/model_processing/onnx_quantizer/layers/unsqueeze.py (1)
  • UnsqueezeQuantizer (30-178)
python/core/model_processing/converters/onnx_converter.py (2)
python/core/model_processing/errors.py (1)
  • LayerAnalysisError (78-92)
python/tests/onnx_quantizer_tests/test_base_layer.py (1)
  • initializer_map (55-59)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: End-to-End Testing (3.10)
  • GitHub Check: End-to-End Testing (3.11)
  • GitHub Check: Check Formatting and Linting
  • GitHub Check: End-to-End Testing (3.12)
🔇 Additional comments (46)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (1)

37-45: LGTM!

The freivalds_reps field is correctly added with proper documentation and a sensible default of 1. The serde attribute ensures backward compatibility when deserializing older JSON configurations that lack this field.

python/core/model_processing/converters/onnx_converter.py (2)

65-65: LGTM!

The constant _N_UNSQUEEZE_INPUTS clearly documents the expected input count for opset ≥13 Unsqueeze operations.


558-596: LGTM! Well-structured axes extraction logic.

The special-case handling for Unsqueeze with axes-as-input is thorough:

  • Validates input count matches the expected schema
  • Ensures axes is a constant initializer (rejecting dynamic axes)
  • Validates integer dtype for axes
  • Handles both scalar and vector axes representations
  • Provides clear error messages for each failure case

The implementation correctly bridges opset ≥13 behavior to the uniform downstream representation expected by the circuit builder.

python/core/model_processing/onnx_quantizer/layers/unsqueeze.py (4)

21-28: LGTM!

The QuantizeUnsqueeze configuration correctly defines this as a scale-preserving, passthrough operation with no weight/bias handling or scaling requirements.


30-72: LGTM!

The UnsqueezeQuantizer class is well-structured with proper initialization and delegation to the base quantizer for standard passthrough behavior.


76-121: LGTM! Robust axes extraction logic.

The helper methods correctly handle both axes sources:

  • _get_axes_from_attribute: Straightforward attribute extraction
  • _get_axes_from_initializer_input: Properly validates initializer presence, dtype, and dimensionality with clear error messages for each failure case

123-178: LGTM! Comprehensive validation logic.

The check_supported method thoroughly validates:

  • Input count (1 or 2, depending on schema style)
  • Axes presence from either attribute or initializer
  • Uniqueness of axes values

The validation logic correctly handles both older opset (axes as attribute) and newer opset (axes as initializer) schemas.

rust/jstprove_circuits/src/circuit_functions/layers/sub.rs (1)

9-9: LGTM!

The import path update aligns with the refactoring that consolidates linear algebra operations into the dedicated gadgets::linear_algebra module.

rust/jstprove_circuits/src/circuit_functions/layers/mul.rs (1)

9-9: LGTM!

The import path update aligns with the refactoring that consolidates linear algebra operations into the dedicated gadgets::linear_algebra module.

python/tests/onnx_quantizer_tests/layers/base.py (1)

101-101: LGTM!

Extending the int64 conversion to include "axes" initializers is correct and aligns with the ONNX Unsqueeze opset ≥13 behavior where axes are provided as integer initializers.

rust/jstprove_circuits/src/circuit_functions/gadgets/mod.rs (1)

5-5: LGTM!

The public exposure of the linear_algebra module enables the import path refactoring seen across the layer implementations.

rust/jstprove_circuits/src/circuit_functions/layers/add.rs (1)

9-9: LGTM!

The import path update aligns with the refactoring that consolidates linear algebra operations into the dedicated gadgets::linear_algebra module.

python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2)

34-38: LGTM!

The imports for SqueezeQuantizer and UnsqueezeQuantizer are correctly placed and follow the existing import ordering pattern.


97-98: LGTM!

The registration of SqueezeQuantizer and UnsqueezeQuantizer follows the established pattern of passing self.new_initializers to quantizers that handle initializer inputs (axes in this case).

rust/jstprove_circuits/src/circuit_functions/layers/mod.rs (1)

17-19: LGTM!

The new public module declarations for squeeze and unsqueeze are correctly placed and follow the existing visibility pattern for layer modules.

rust/jstprove_circuits/src/circuit_functions/layers/batchnorm.rs (1)

6-12: LGTM!

The import path update correctly reflects the refactoring of linear algebra functions from the layers module to gadgets::linear_algebra. The function signatures remain unchanged as confirmed by the relevant code snippets.

rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2)

22-23: LGTM!

The imports for SqueezeLayer and UnsqueezeLayer are correctly added following the existing import pattern.


149-150: LGTM!

The registry entries for Squeeze and Unsqueeze are correctly wired to their respective layer builders. The pattern matches existing layer registrations, and the builder signatures align with the LayerOp trait requirements as shown in the relevant code snippets.

python/tests/onnx_quantizer_tests/layers/squeeze_config.py (3)

1-14: LGTM!

The imports are well-organized and include all necessary dependencies for the test configuration provider.


17-39: LGTM!

The SqueezeConfigProvider class and get_config method are well-structured. The base configuration correctly represents the newer opset form where axes is provided as an int64 initializer input.


41-114: LGTM!

The test specifications provide comprehensive coverage:

  • Valid tests cover axes-omitted semantics, basic initializer usage, singleton axis, and negative axis handling
  • Error tests verify rejection of duplicate axes and unsupported dynamic axes input
  • E2E tests ensure full pipeline verification

The test case logic is correct (e.g., axis -2 on a rank-3 tensor resolves to axis 1).

python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (3)

1-14: LGTM!

The imports are consistent with the squeeze_config.py pattern and include all necessary dependencies.


17-41: LGTM!

The UnsqueezeConfigProvider class correctly configures the newer opset form where axes is provided as an int64 initializer. The output shape [1, 3, 1, 5] correctly reflects inserting dimensions at axes 0 and 2 into input shape [3, 5].


43-114: LGTM!

The test specifications provide solid coverage for Unsqueeze:

  • Valid tests cover various axis configurations including single, multiple, and negative axes
  • Error tests correctly verify rejection of duplicate and dynamic axes
  • The absence of an "axes_omitted" test is appropriate since Unsqueeze requires axes (unlike Squeeze which can default to removing all size-1 dims)

Output shape calculations are correct for all cases.

rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (4)

1-25: LGTM!

The imports and struct definition are well-organized. The UnsqueezeLayer struct correctly captures all necessary fields for the operation, and the #[allow(dead_code)] attribute is appropriate for fields used only in error messages.


27-61: LGTM!

The normalize_axes method correctly handles:

  • Negative axis conversion using output rank
  • Range validation (0 ≤ axis < rank_out)
  • Duplicate detection via HashSet
  • Returns sorted axes for predictable shape construction

The error messages are descriptive and include context.
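A Python transliteration of the normalization steps described above (an illustrative sketch; the Rust normalize_axes implementation is the source of truth):

```python
def normalize_axes(axes, rank_out):
    """Map possibly-negative Unsqueeze axes into [0, rank_out) and sort them."""
    seen, out = set(), []
    for a in axes:
        norm = a + rank_out if a < 0 else a
        if not 0 <= norm < rank_out:
            raise ValueError(f"axis {a} out of range for output rank {rank_out}")
        if norm in seen:
            raise ValueError(f"duplicate axis {a}")
        seen.add(norm)
        out.append(norm)
    return sorted(out)

# Input rank 2 plus two inserted axes gives output rank 4; -1 resolves to 3.
normalized = normalize_axes([-1, 0], rank_out=4)
```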


63-101: LGTM!

The unsqueezed_shape method correctly implements ONNX Unsqueeze semantics:

  • Computes output rank as rank_in + axes.len()
  • Inserts 1s at specified axes positions
  • Validates all input dimensions are consumed
  • Error handling covers edge cases like running out of input dims
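The shape computation reviewed here amounts to inserting a 1 at each (sorted, normalized) axis while consuming input dims in order; an illustrative Python re-derivation, not the actual Rust code:

```python
def unsqueezed_shape(in_shape, sorted_axes):
    """ONNX Unsqueeze shape inference: insert a size-1 dim at each axis."""
    rank_out = len(in_shape) + len(sorted_axes)
    axes = set(sorted_axes)
    out, dims = [], iter(in_shape)
    for i in range(rank_out):
        # Each output position is either an inserted 1 or the next input dim.
        out.append(1 if i in axes else next(dims))
    return out

shape = unsqueezed_shape([3, 5], sorted_axes=[0, 2])
```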

135-161: LGTM!

The build method correctly extracts layer metadata and requires the axes parameter (via get_param which errors if missing). This aligns with ONNX Unsqueeze semantics where axes is mandatory. The pattern is consistent with other layer implementations shown in the relevant code snippets.

rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (3)

71-90: LGTM! Clean struct definition with well-documented fields.

The GemmLayer struct is well-organized. The transition from v_plus_one to source_scale_exponent and the addition of freivalds_reps aligns with the PR objectives.


150-159: LGTM! Core product computation and bias addition are well-structured.

The flow of computing the core product (with optional Freivalds verification) followed by constrained bias addition is correct. The separation of concerns between compute_core_product and bias handling is clean.


200-209: Good defensive validation for freivalds_reps.

Build-time validation ensuring freivalds_reps != 0 prevents accidentally disabling verification. This is correctly duplicated from the runtime check in should_use_freivalds for fail-fast behavior.

rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (5)

17-25: LGTM! Clean struct definition for SqueezeLayer.

The struct captures all necessary fields for the Squeeze operation. The axes: Option<Vec<i64>> correctly models ONNX semantics where axes can be omitted.


27-65: LGTM! Robust axis normalization with proper validation.

The normalize_axes method correctly handles:

  • Negative axis indices (Python/ONNX convention)
  • Out-of-range validation
  • Duplicate detection
  • Consistent sorted output

The error messages are informative and include context (axis value, rank, axes list).


67-106: LGTM! Correct ONNX Squeeze shape inference.

The squeezed_shape implementation correctly handles both modes:

  1. Axes omitted → remove all dimensions of size 1
  2. Axes specified → validate and remove only those dimensions

The validation that specified axes must have dim == 1 is correct per ONNX spec.
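Both modes of ONNX Squeeze shape inference can be sketched as follows (illustrative Python, not the reviewed Rust implementation):

```python
def squeezed_shape(shape, axes=None):
    """ONNX Squeeze shape inference: drop size-1 dims, all or at given axes."""
    if axes is None:
        # Mode 1: axes omitted -> remove every dimension of size 1.
        return [d for d in shape if d != 1]
    # Mode 2: normalize negatives, then require each named dim to be 1.
    norm = sorted(a + len(shape) if a < 0 else a for a in axes)
    for a in norm:
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze axis {a}: dim is {shape[a]}, not 1")
    drop = set(norm)
    return [d for i, d in enumerate(shape) if i not in drop]

all_ones_dropped = squeezed_shape([1, 3, 1, 5])       # axes omitted
only_axis_0 = squeezed_shape([1, 3, 1, 5], axes=[0])  # axes specified
```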


109-138: LGTM! Correct passthrough implementation.

The apply method correctly implements Squeeze as a pure shape transform:

  • No circuit constraints are added (as expected for a view/reshape operation)
  • Element order is preserved via flatten → reshape
  • Error handling for invalid reshapes is present

154-160: LGTM! Correct handling of optional axes parameter.

The conditional check for params.get(AXES) before calling get_param correctly handles both:

  • Axes omitted (older ONNX semantics or "remove all size-1 dims")
  • Axes present as attribute/parameter
python/core/model_processing/onnx_quantizer/layers/squeeze.py (4)

21-28: LGTM! Clean mixin defining Squeeze characteristics.

The QuantizeSqueeze mixin correctly specifies that Squeeze:

  • Has no weights/biases (USE_WB = False)
  • Requires no scaling (USE_SCALING = False)
  • Only the data input (index 0) is relevant for scale planning

46-68: LGTM! Clean passthrough quantization.

The quantize method correctly delegates to the base implementation, reflecting that Squeeze is a pure shape transform requiring no special quantization logic.


79-118: LGTM! Robust initializer parsing with comprehensive validation.

The _get_axes_from_initializer_input method handles:

  • Missing initializer (dynamic axes) → clear error
  • Non-integer dtype → validation error
  • 0-D scalar → converts to single-element list
  • 1-D array → converts to list
  • 2D+ arrays → rejected

Error messages include useful context (node_name, op_type, expected values).


120-154: LGTM! Well-structured validation with appropriate deferrals.

The check_supported method:

  1. Validates structural constraints (input count, axes source)
  2. Rejects dynamic axes (not in initializer_map)
  3. Validates axes uniqueness
  4. Correctly defers shape-dependent validation (e.g., "is axis actually squeezable?") to the Rust circuit layer

This separation of concerns is clean and matches the PR design.

rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (6)

1-17: LGTM! Excellent module documentation.

The module-level documentation clearly explains:

  • The distinction between constrained and unconstrained gadgets
  • The security implications of unconstrained operations
  • The purpose of Freivalds verification
  • Error handling via LayerError / LayerKind

47-77: LGTM! Correct dot product implementation.

The dot function correctly validates inputs are 1D vectors of equal length and computes the sum of element-wise products using constrained operations.

Note: The #[allow(dead_code)] annotation suggests this may be unused currently. Consider removing it if/when it's integrated into a layer.


223-269: LGTM! Clean factorization of elementwise operations.

The elementwise_op helper effectively factors out common logic:

  • Shape matching with automatic reshape when element counts match
  • Consistent error handling across add/mul/sub
  • Clean use of function references for the operation

The reshape-by-element-count behavior is documented and useful for bias addition scenarios.


294-338: LGTM! Standard constrained matrix multiplication.

The implementation correctly:

  • Validates both inputs are 2D
  • Checks inner dimension compatibility (A.cols == B.rows)
  • Computes the product using constrained mul and add operations

344-403: LGTM! Well-documented unconstrained multiplication with clear security warning.

The documentation at lines 349-353 clearly warns that:

  • This adds no constraints relating C to A and B
  • The result is only a witness suggestion
  • It must be linked back via freivalds_verify_matrix_product or full constraints

This is critical information for maintainers and is appropriately highlighted.


449-563: Correct Freivalds implementation; soundness depends on get_random_value() framework semantics.

The implementation correctly follows the standard Freivalds algorithm: for each repetition (line 515-560), it samples a random vector (line 519), computes B·x, C·x, and A·(B·x), then asserts equality. Matrix dimension validation and num_repetitions enforcement are thorough.

The critical dependency is that api.get_random_value() (line 519) provides cryptographically-suitable, prover-unbiased randomness. In zero-knowledge contexts, if the prover could influence randomness after seeing A, B, and C, the soundness bound would degrade. Confirm that the expander_compiler framework's RootAPI trait ensures randomness is either: (1) derived from a random oracle independent of circuit inputs, or (2) generated from a verifier challenge committed before proof computation.
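For reference, the Freivalds check itself is easy to sketch in NumPy; here a seeded np.random generator stands in for api.get_random_value(), which in the real circuit must be prover-unbiased as discussed above:

```python
import numpy as np

def freivalds_check(a, b, c, reps=3, rng=None):
    """Probabilistically verify c == a @ b with O(n^2) work per repetition.

    Each rep samples a random vector x and checks a @ (b @ x) == c @ x;
    a wrong c survives a rep only if (c - a @ b) @ x happens to be zero.
    """
    rng = rng or np.random.default_rng(0)
    n = b.shape[1]
    for _ in range(reps):
        x = rng.integers(1, 2**31, size=n)  # strictly positive entries
        if not np.array_equal(a @ (b @ x), c @ x):
            return False
    return True

a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
good = freivalds_check(a, b, a @ b)      # correct product passes
bad = freivalds_check(a, b, a @ b + 1)   # uniformly wrong product is caught
```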


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1)

17-25: Clarify the purpose of input_shape and the dead_code attribute.

The input_shape field is populated during construction but never used in the apply method, which instead recomputes the shape from the actual input tensor. If input_shape serves no validation or debugging purpose, consider removing it to simplify the struct. Additionally, the #[allow(dead_code)] attribute suggests the compiler is flagging unused code—confirm whether all fields are intentionally retained.

📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5e38482 and f6f53e0.

📒 Files selected for processing (1)
  • rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (2)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
  • extract_params_and_expected_shape (176-208)
  • get_input_name (255-267)
  • get_param (223-240)
rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)
  • input (105-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Check Formatting and Linting
  • GitHub Check: End-to-End Testing (3.10)
  • GitHub Check: End-to-End Testing (3.11)
  • GitHub Check: End-to-End Testing (3.12)
🔇 Additional comments (5)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (5)

28-61: LGTM: Axis normalization is correct.

The implementation properly handles negative indices, validates bounds, detects duplicates, and sorts the normalized axes. This aligns with ONNX Unsqueeze semantics.
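The normalization steps above (resolve negatives, bound-check, reject duplicates, sort) can be sketched in Python; this is illustrative only, with hypothetical names, while the real implementation lives in unsqueeze.rs:

```python
def normalize_unsqueeze_axes(axes, input_rank):
    """Mirror of the described logic: resolve negative axes against the
    OUTPUT rank, reject out-of-bounds values and duplicates, then sort."""
    out_rank = input_rank + len(axes)  # Unsqueeze axes index into the output shape
    seen = set()
    normalized = []
    for ax in axes:
        ax_n = ax + out_rank if ax < 0 else ax
        if not 0 <= ax_n < out_rank:
            raise ValueError(f"axis {ax} out of bounds for output rank {out_rank}")
        if ax_n in seen:
            raise ValueError(f"duplicate axis {ax}")
        seen.add(ax_n)
        normalized.append(ax_n)
    return sorted(normalized)
```

Per the ONNX spec, valid axes lie in [-output_rank, output_rank - 1], which is why negatives are offset by the output rank rather than the input rank.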


63-101: LGTM: Shape computation is correct.

The method correctly computes the output shape by inserting dimensions of size 1 at the specified axes while preserving the order of input dimensions. The validation that all input dimensions are consumed adds robustness.
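The shape computation reads naturally as "walk the output positions, emitting 1 at each axis and the next input dimension otherwise"; a minimal Python sketch (hypothetical names, assuming axes are already normalized and sorted):

```python
def unsqueezed_shape(input_shape, sorted_axes):
    """Insert size-1 dims at the normalized axes while preserving the
    order of input dimensions; verify every input dim is consumed."""
    out_rank = len(input_shape) + len(sorted_axes)
    axes_set = set(sorted_axes)
    dims = iter(input_shape)
    out = []
    for i in range(out_rank):
        out.append(1 if i in axes_set else next(dims))
    if next(dims, None) is not None:  # robustness check noted above
        raise ValueError("unconsumed input dimensions")
    return out
```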


105-133: LGTM: Passthrough reshape implementation is correct.

The apply method correctly implements unsqueeze as a pure shape transformation—flattening the input and reshaping to the computed output shape while preserving element order. No gates are added, which aligns with the design goal of treating Unsqueeze as a passthrough layer.
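Numerically, the passthrough amounts to a flatten-then-reshape that touches no values, which is why no gates are needed; a numpy sketch of the idea (not the actual circuit code):

```python
import numpy as np

def unsqueeze_apply(x, out_shape):
    """Pure shape transform: values and row-major element order are
    untouched, so the circuit adds no constraints."""
    flat = np.asarray(x).reshape(-1)   # row-major flatten
    return flat.reshape(out_shape)     # reinterpret under the new shape
```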


1-15: LGTM: Imports are clean and appropriate.

All imports are utilized and follow the standard patterns for this codebase.


135-161: The Rust build implementation is correct for opset ≥13 Unsqueeze operations.

The Python converter properly normalizes axes into params["axes"] whether axes comes as an attribute or constant initializer, handling both scalar and array forms. The get_param() call in the Rust code correctly receives this normalized parameter, ensuring consistent behavior across all ONNX opset versions.

@tmfreiberg tmfreiberg changed the title from "Squeeze" to "Add Squeeze and Unsqueeze operator support across quantizer and circuit builder" Dec 19, 2025
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (2)

76-117: Consider using Option<&[i64]> for more idiomatic Rust.

The squeeze logic is correct and properly validates that specified axes have dimension size 1 before squeezing. However, the parameter type Option<&Vec<i64>> is non-idiomatic Rust: accepting Option<&[i64]> lets callers pass any contiguous sequence without committing to Vec.

🔎 Proposed refactor

Change the signature to use Option<&[i64]>:

 fn squeezed_shape(
     &self,
     input_shape: &[usize],
-    axes: Option<&Vec<i64>>,
+    axes: Option<&[i64]>,
 ) -> Result<Vec<usize>, CircuitError> {

Then at line 91, simplify:

-let axes_ref = axes.expect("axes is Some here");
-let axes_u = self.normalize_axes(axes_ref.as_slice(), rank)?;
+let axes_slice = axes.unwrap(); // Safe: already checked is_none() above
+let axes_u = self.normalize_axes(axes_slice, rank)?;

And at the call site (line 136), change to:

-let out_shape = self.squeezed_shape(&in_shape, self.axes.as_ref())?;
+let out_shape = self.squeezed_shape(&in_shape, self.axes.as_deref())?;

121-149: Remove redundant clone on line 128.

The apply method correctly implements pure shape transformation with no circuit constraints, matching the PR objectives. However, there's a minor inefficiency.

🔎 Proposed fix
 let layer_input = input
-    .get(&input_name.clone())
+    .get(input_name)
     .ok_or_else(|| LayerError::MissingInput {

HashMap::get borrows its key, so passing input_name without cloning works directly. This matches the pattern in unsqueeze.rs line 125.

📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f6f53e0 and 842ddd5.

📒 Files selected for processing (2)
  • rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (1 hunks)
  • rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs
🧰 Additional context used
🧬 Code graph analysis (1)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (3)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
  • extract_params_and_expected_shape (176-208)
  • get_input_name (255-267)
  • get_param (223-240)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (3)
  • normalize_axes (28-75)
  • apply (119-147)
  • build (149-175)
rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)
  • input (105-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Check Formatting and Linting
  • GitHub Check: End-to-End Testing (3.11)
  • GitHub Check: End-to-End Testing (3.12)
  • GitHub Check: End-to-End Testing (3.10)
🔇 Additional comments (4)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (4)

1-15: LGTM!

The imports and constant definition are appropriate and follow the established patterns in other layer implementations.


17-25: LGTM!

The struct definition correctly models Squeeze semantics with optional axes: None squeezes all size-1 dimensions, while Some squeezes only the specified dimensions.
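Those two modes can be sketched in Python (illustrative only; the real logic is the Rust squeezed_shape method):

```python
def squeezed_shape(input_shape, axes=None):
    """axes=None: drop every size-1 dim. axes given: drop exactly the
    listed dims, each of which must have size 1 (ONNX Squeeze semantics)."""
    rank = len(input_shape)
    if axes is None:
        return [d for d in input_shape if d != 1]
    norm = {ax + rank if ax < 0 else ax for ax in axes}
    for ax in norm:
        if not 0 <= ax < rank:
            raise ValueError(f"axis out of bounds: {ax}")
        if input_shape[ax] != 1:
            raise ValueError(f"cannot squeeze axis {ax} of size {input_shape[ax]}")
    return [d for i, d in enumerate(input_shape) if i not in norm]
```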


28-74: LGTM!

The axis normalization logic correctly handles negative indices, validates bounds, detects duplicates, and provides clear error messages. The implementation follows the same pattern as UnsqueezeLayer.


151-182: LGTM!

The build method correctly handles the optional axes parameter, which is essential for ONNX Squeeze semantics (when axes are omitted, all size-1 dimensions are removed). The conditional extraction pattern at lines 168-171 is appropriate.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
python/core/model_processing/onnx_quantizer/layers/squeeze.py (1)

46-52: Consider always initializing new_initializers for clarity.

The conditional initialization on lines 51-52 only sets self.new_initializers when the parameter is not None. If external code accesses this attribute without using hasattr(), this could raise AttributeError. For clarity and defensive programming, consider always initializing:

self.new_initializers = new_initializers if new_initializers is not None else []

This ensures the attribute always exists with a predictable type.

📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 842ddd5 and 30e86a9.

📒 Files selected for processing (2)
  • python/core/model_processing/onnx_quantizer/layers/squeeze.py (1 hunks)
  • python/tests/onnx_quantizer_tests/layers/squeeze_config.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • python/tests/onnx_quantizer_tests/layers/squeeze_config.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: End-to-End Testing (3.12)
  • GitHub Check: End-to-End Testing (3.11)
  • GitHub Check: End-to-End Testing (3.10)
  • GitHub Check: Check Formatting and Linting
🔇 Additional comments (7)
python/core/model_processing/onnx_quantizer/layers/squeeze.py (7)

1-19: LGTM: Clean imports and proper TYPE_CHECKING usage.

The imports are well-organized, and using TYPE_CHECKING for the onnx import avoids runtime circular dependencies.


21-28: LGTM: Correct passthrough configuration.

The configuration correctly marks Squeeze as a scale-preserving operation with no arithmetic or rescaling. The SCALE_PLAN appropriately identifies only the data input (index 0) as relevant for scale planning.


54-68: LGTM: Clean delegation to base quantizer.

The method correctly delegates to the base QuantizeSqueeze.quantize for passthrough behavior, maintaining consistency with the design philosophy for shape-only operations.


70-71: LGTM: Clear constant definitions.

The constants clearly distinguish between the two input configurations for Squeeze, improving code readability.


73-77: LGTM: Straightforward attribute extraction.

The method correctly extracts the axes attribute when present and returns None for the optional parameter case.


79-118: LGTM: Robust axes extraction with comprehensive validation.

The method correctly:

  • Rejects dynamic (runtime-variable) axes as required by the circuit backend
  • Validates integer dtype and proper dimensionality (0-D or 1-D)
  • Provides clear, actionable error messages

The explicit int() conversion on line 110 is technically redundant since arr.tolist() already returns Python ints for integer arrays, but it's defensive and ensures type consistency, so it's acceptable.
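The validation flow described above can be sketched as follows; in the real quantizer the arrays come from onnx.numpy_helper.to_array on graph initializers, and the names here are hypothetical:

```python
import numpy as np

def extract_constant_axes(initializer_map, axes_name):
    """Reject dynamic axes, enforce integer dtype and 0-D/1-D shape,
    and return axes as plain Python ints."""
    if axes_name not in initializer_map:
        raise ValueError(
            f"axes input '{axes_name}' is not a constant initializer; "
            "dynamic axes are unsupported by the circuit backend"
        )
    arr = initializer_map[axes_name]
    if not np.issubdtype(arr.dtype, np.integer):
        raise ValueError("axes must have an integer dtype")
    if arr.ndim > 1:
        raise ValueError("axes must be 0-D or 1-D")
    return [int(v) for v in np.atleast_1d(arr)]
```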


120-155: LGTM: Comprehensive validation logic.

The validation correctly handles all supported Squeeze configurations:

  • Axes via attribute (older opsets)
  • Axes via constant initializer (newer opsets)
  • Omitted axes (ONNX default: squeeze all size-1 dimensions)

The duplicate axes check is properly placed after the early return, and the comment on lines 142-143 clearly explains why shape validation is deferred to the Rust backend. The error messages throughout are clear and actionable.

@jsgold-1 jsgold-1 (Collaborator) left a comment

Looks good overall! Just a couple points. I think we need to move some of the code around a bit, but overall looks okay. Happy to chat about where we can move things to

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@python/core/model_processing/converters/onnx_converter.py`:
- Around lines 74-135: _extract_unsqueeze_axes_into_params currently treats any
  non-initializer axes input as dynamic and raises LayerAnalysisError. Change it
  to also accept Constant-node axes: add a new parameter model_graph_nodes (a
  list of NodeProto), and, after building initializer_map, if axes_name is not
  in the map, search model_graph_nodes for a node with op_type == "Constant"
  whose output name equals axes_name; extract its "value" attribute
  (value_attr.t) and convert it with numpy_helper.to_array(value_attr.t) to
  populate axes_arr. Keep the existing integer-dtype check and axes
  normalization logic, and update callers (e.g. _pre_transform_unsqueeze) to
  pass model.graph.node as model_graph_nodes.
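The Constant-node lookup the fix describes can be sketched with plain objects standing in for onnx NodeProto (the real code would then run numpy_helper.to_array on the returned tensor attribute):

```python
def find_constant_value_attr(graph_nodes, output_name):
    """Scan the graph for a Constant node producing `output_name` and
    return its 'value' attribute, or None if the axes are not constant."""
    for node in graph_nodes:
        if node.op_type == "Constant" and output_name in node.output:
            for attr in node.attribute:
                if attr.name == "value":
                    return attr
    return None
```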

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@python/core/model_processing/converters/onnx_converter.py`:
- Around lines 664-667: Add an optional model parameter to analyze_layer (e.g.,
  def analyze_layer(..., model: Optional[onnx.ModelProto] = None)) and use
  model or self.model inside it; update get_model_architecture to pass
  model=model when calling analyze_layer. Change the pre-analysis transform
  invocation in the CONVERTER_PRE_ANALYSIS_TRANSFORMS usage from
  transform(node, params, self.model, self.model_type) to
  transform(node, params, model or self.model, self.model_type), so that
  _pre_transform_unsqueeze and similar transforms extract Constant/initializer
  axes from the graph actually being analyzed. For backward compatibility,
  make every internal reference in analyze_layer that previously used
  self.model use the passed model when one is provided.
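The backward-compatible fallback being requested reduces to a simple pattern; a hypothetical minimal sketch, not the actual converter class:

```python
class ConverterSketch:
    """Stand-in for the converter: analyze_layer accepts an optional model
    and falls back to the one loaded at construction time."""

    def __init__(self, model):
        self.model = model

    def analyze_layer(self, node, params, model=None):
        # Existing callers pass no model and keep their old behavior.
        active_model = model if model is not None else self.model
        # ...pre-analysis transforms would run against active_model here...
        return active_model
```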

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@python/core/model_processing/onnx_quantizer/layers/unsqueeze.py`:
- Around lines 212-220: Change the signature of
  _extract_unsqueeze_axes_into_params to take graph: onnx.GraphProto instead of
  onnx.ModelProto, since the function accesses graph.initializer and
  graph.node; update any callers as needed. Also delete the now-redundant
  "# noqa: PLR0913" directive from that line.

@jsgold-1 jsgold-1 (Collaborator) left a comment

LGTM

@tmfreiberg tmfreiberg merged commit 6531894 into main Jan 29, 2026
8 checks passed