Add Squeeze and Unsqueeze operator support across quantizer and circuit builder #88
Conversation
Walkthrough

Adds ONNX Squeeze and Unsqueeze quantizers and tests, registers them in the ONNX op quantizer, implements corresponding Rust circuit layers and an AXES constant, updates test initializer handling, and applies a minor formatting change in the ONNX converter get_weights signature.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 0
🧹 Nitpick comments (3)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1)
104-133: Minor: Redundant clone on line 112.

`input_name` is already a `&String`, so the `.clone()` before `.get()` is unnecessary: `HashMap::get` accepts `&Q` where `String: Borrow<Q>`. The core logic is correct: flattening and reshaping preserves element order, which is the correct behavior for Unsqueeze.

🔎 Proposed fix

```diff
 let input_name = get_input_name(&self.inputs, 0, LayerKind::Unsqueeze, INPUT)?;
 let layer_input = input
-    .get(&input_name.clone())
+    .get(input_name)
     .ok_or_else(|| LayerError::MissingInput {
         layer: LayerKind::Unsqueeze,
         name: input_name.clone(),
     })?
     .clone();
```

rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (2)
283-310: Consider reducing redundant clones in the Freivalds path.

The Freivalds branch clones both `input_array` and `weights_array` twice: once for `unconstrained_matrix_multiplication` and again for `freivalds_verify_matrix_product`. Since `freivalds_verify_matrix_product` takes references, you could avoid the second clone by passing references to the already-converted dynamic arrays.

🔎 Proposed optimization to reduce clones

```diff
 if use_freivalds {
+    let input_dyn = input_array.clone().into_dyn();
+    let weights_dyn = weights_array.clone().into_dyn();
+
     let core_dyn = unconstrained_matrix_multiplication(
         api,
-        input_array.clone().into_dyn(),
-        weights_array.clone().into_dyn(),
+        input_dyn.clone(),
+        weights_dyn.clone(),
         layer_kind.clone(),
     )?;
     freivalds_verify_matrix_product(
         api,
-        &input_array.clone().into_dyn(),
-        &weights_array.clone().into_dyn(),
+        &input_dyn,
+        &weights_dyn,
         &core_dyn,
         layer_kind,
         freivalds_reps,
     )?;
```
357-396: Cost model looks reasonable; consider documenting the derivation.

The cost model for deciding Freivalds vs. full matmul is sensible. A few observations:

- The `d == 0` guard (lines 388-390) correctly handles degenerate cases where Freivalds would have no meaningful cost advantage.
- Using `saturating_*` arithmetic is the right approach to avoid panics on large dimensions.

Consider adding a brief inline comment or doc reference explaining where the cost formulas come from (e.g., counting `mul`+`add` operations in each approach) to help future maintainers verify or tune the heuristic.
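The trade-off can be sketched in a few lines. This is a hypothetical constraint-count model for illustration, not the actual formulas in gemm.rs: it assumes a fully constrained (m, k) × (k, n) matmul costs about m·n·k multiply-add constraints, while Freivalds costs roughly reps × (k·n + m·k + m·n) for its three matrix-vector products per repetition.

```python
# Hypothetical cost heuristic; the real gemm.rs code may weight operations
# differently and uses saturating integer arithmetic.

def full_matmul_cost(m: int, k: int, n: int) -> int:
    # One constrained multiply-add per (row, col, inner) triple.
    return m * n * k

def freivalds_cost(m: int, k: int, n: int, reps: int) -> int:
    # Per repetition: B @ x costs k*n, A @ (B @ x) costs m*k, C @ x costs m*n.
    return reps * (k * n + m * k + m * n)

def should_use_freivalds(m: int, k: int, n: int, reps: int = 1) -> bool:
    d = min(m, k, n)
    if d == 0:  # degenerate shapes: no meaningful advantage
        return False
    return freivalds_cost(m, k, n, reps) < full_matmul_cost(m, k, n)
```

For large square matrices the verification wins (cubic vs. quadratic), while for tiny matrices the full constrained product is cheaper.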
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- python/core/model_processing/converters/onnx_converter.py (2 hunks)
- python/core/model_processing/onnx_quantizer/layers/squeeze.py (1 hunk)
- python/core/model_processing/onnx_quantizer/layers/unsqueeze.py (1 hunk)
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2 hunks)
- python/tests/onnx_quantizer_tests/layers/base.py (1 hunk)
- python/tests/onnx_quantizer_tests/layers/max_config.py (0 hunks)
- python/tests/onnx_quantizer_tests/layers/squeeze_config.py (1 hunk)
- python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/gadgets/mod.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/add.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/batchnorm.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (11 hunks)
- rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2 hunks)
- rust/jstprove_circuits/src/circuit_functions/layers/math.rs (0 hunks)
- rust/jstprove_circuits/src/circuit_functions/layers/mod.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/mul.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/sub.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (1 hunk)
💤 Files with no reviewable changes (2)
- python/tests/onnx_quantizer_tests/layers/max_config.py
- rust/jstprove_circuits/src/circuit_functions/layers/math.rs
🧰 Additional context used
🧬 Code graph analysis (12)
rust/jstprove_circuits/src/circuit_functions/layers/mul.rs (1)

- rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1): `matrix_hadamard_product` (143-157)

rust/jstprove_circuits/src/circuit_functions/layers/sub.rs (1)

- rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1): `matrix_subtraction` (183-197)

python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (3)

- python/core/model_processing/onnx_quantizer/exceptions.py (1): `InvalidParamError` (27-67)
- python/tests/onnx_quantizer_tests/layers/base.py (12): `e2e_test` (247-248), `error_test` (239-240), `valid_test` (235-236), `BaseLayerConfigProvider` (251-277), `LayerTestConfig` (58-166), `description` (175-177), `override_inputs` (191-193), `override_initializer` (187-189), `override_input_shapes` (195-197), `override_output_shapes` (199-201), `tags` (215-217), `expects_error` (203-213)
- python/tests/onnx_quantizer_tests/layers/squeeze_config.py (3): `layer_name` (21-22), `get_config` (24-39), `get_test_specs` (41-114)

rust/jstprove_circuits/src/circuit_functions/layers/batchnorm.rs (1)

- rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (2): `matrix_addition` (103-117), `matrix_hadamard_product` (143-157)

rust/jstprove_circuits/src/circuit_functions/layers/add.rs (1)

- rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (1): `matrix_addition` (103-117)

rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (2)

- rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (2): `freivalds_verify_matrix_product` (449-563), `matrix_multiplication` (294-338)
- rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (2): `get_w_or_b` (69-113), `get_param_or_default` (283-305)

rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2)

- rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1): `build` (135-161)
- rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (1): `build` (140-171)

rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (4)

- python/core/circuits/errors.py (1): `CircuitError` (6-32)
- rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3): `extract_params_and_expected_shape` (176-208), `get_input_name` (255-267), `get_param` (223-240)
- rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1): `input` (105-105)
- python/core/model_processing/converters/onnx_converter.py (1): `ONNXLayer` (69-108)

python/tests/onnx_quantizer_tests/layers/squeeze_config.py (3)

- python/core/model_processing/onnx_quantizer/exceptions.py (1): `InvalidParamError` (27-67)
- python/tests/onnx_quantizer_tests/layers/base.py (15): `e2e_test` (247-248), `error_test` (239-240), `valid_test` (235-236), `BaseLayerConfigProvider` (251-277), `LayerTestConfig` (58-166), `layer_name` (260-261), `get_config` (255-256), `get_test_specs` (263-265), `description` (175-177), `override_inputs` (191-193), `override_input_shapes` (195-197), `override_output_shapes` (199-201), `tags` (215-217), `build` (223-231), `override_initializer` (187-189)
- python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (3): `layer_name` (21-22), `get_config` (24-41), `get_test_specs` (43-114)

rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (2)

- rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3): `extract_params_and_expected_shape` (176-208), `get_input_name` (255-267), `get_param` (223-240)
- rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1): `input` (105-105)

python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2)

- python/core/model_processing/onnx_quantizer/layers/squeeze.py (1): `SqueezeQuantizer` (30-154)
- python/core/model_processing/onnx_quantizer/layers/unsqueeze.py (1): `UnsqueezeQuantizer` (30-178)

python/core/model_processing/converters/onnx_converter.py (2)

- python/core/model_processing/errors.py (1): `LayerAnalysisError` (78-92)
- python/tests/onnx_quantizer_tests/test_base_layer.py (1): `initializer_map` (55-59)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: End-to-End Testing (3.10)
- GitHub Check: End-to-End Testing (3.11)
- GitHub Check: Check Formatting and Linting
- GitHub Check: End-to-End Testing (3.12)
🔇 Additional comments (46)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (1)
37-45: LGTM! The `freivalds_reps` field is correctly added with proper documentation and a sensible default of 1. The serde attribute ensures backward compatibility when deserializing older JSON configurations that lack this field.

python/core/model_processing/converters/onnx_converter.py (2)

65-65: LGTM! The constant `_N_UNSQUEEZE_INPUTS` clearly documents the expected input count for opset ≥13 Unsqueeze operations.
558-596: LGTM! Well-structured axes extraction logic.The special-case handling for Unsqueeze with axes-as-input is thorough:
- Validates input count matches the expected schema
- Ensures axes is a constant initializer (rejecting dynamic axes)
- Validates integer dtype for axes
- Handles both scalar and vector axes representations
- Provides clear error messages for each failure case
The implementation correctly bridges opset ≥13 behavior to the uniform downstream representation expected by the circuit builder.
python/core/model_processing/onnx_quantizer/layers/unsqueeze.py (4)
21-28: LGTM! The `QuantizeUnsqueeze` configuration correctly defines this as a scale-preserving, passthrough operation with no weight/bias handling or scaling requirements.

30-72: LGTM! The `UnsqueezeQuantizer` class is well-structured, with proper initialization and delegation to the base quantizer for standard passthrough behavior.
76-121: LGTM! Robust axes extraction logic. The helper methods correctly handle both axes sources:

- `_get_axes_from_attribute`: straightforward attribute extraction
- `_get_axes_from_initializer_input`: properly validates initializer presence, dtype, and dimensionality, with clear error messages for each failure case
123-178: LGTM! Comprehensive validation logic. The `check_supported` method thoroughly validates:
- Input count (1 or 2, depending on schema style)
- Axes presence from either attribute or initializer
- Uniqueness of axes values
The validation logic correctly handles both older opset (axes as attribute) and newer opset (axes as initializer) schemas.
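The dual-source rule can be sketched as follows. This is a hypothetical helper for illustration, not the quantizer's actual API: one input means axes must come from the node attribute (older opsets), two inputs mean axes must come from a constant initializer (opset ≥13), and duplicates are rejected either way.

```python
def validate_unsqueeze_axes(num_inputs, attr_axes=None, init_axes=None):
    """Validate Unsqueeze axes from either the attribute or the initializer."""
    if num_inputs not in (1, 2):
        raise ValueError(f"Unsqueeze expects 1 or 2 inputs, got {num_inputs}")
    axes = attr_axes if num_inputs == 1 else init_axes
    if axes is None:
        raise ValueError("Unsqueeze requires axes (attribute or constant initializer)")
    if len(set(axes)) != len(axes):
        raise ValueError(f"duplicate axes: {list(axes)}")
    return list(axes)
```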
rust/jstprove_circuits/src/circuit_functions/layers/sub.rs (1)
9-9: LGTM! The import path update aligns with the refactoring that consolidates linear algebra operations into the dedicated `gadgets::linear_algebra` module.

rust/jstprove_circuits/src/circuit_functions/layers/mul.rs (1)

9-9: LGTM! The import path update aligns with the refactoring that consolidates linear algebra operations into the dedicated `gadgets::linear_algebra` module.

python/tests/onnx_quantizer_tests/layers/base.py (1)
101-101: LGTM! Extending the int64 conversion to include "axes" initializers is correct and aligns with the ONNX Unsqueeze opset ≥13 behavior, where axes are provided as integer initializers.

rust/jstprove_circuits/src/circuit_functions/gadgets/mod.rs (1)

5-5: LGTM! The public exposure of the `linear_algebra` module enables the import path refactoring seen across the layer implementations.

rust/jstprove_circuits/src/circuit_functions/layers/add.rs (1)

9-9: LGTM! The import path update aligns with the refactoring that consolidates linear algebra operations into the dedicated `gadgets::linear_algebra` module.

python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2)
34-38: LGTM! The imports for `SqueezeQuantizer` and `UnsqueezeQuantizer` are correctly placed and follow the existing import ordering pattern.

97-98: LGTM! The registration of `SqueezeQuantizer` and `UnsqueezeQuantizer` follows the established pattern of passing `self.new_initializers` to quantizers that handle initializer inputs (axes, in this case).

rust/jstprove_circuits/src/circuit_functions/layers/mod.rs (1)
17-19: LGTM! The new public module declarations for `squeeze` and `unsqueeze` are correctly placed and follow the existing visibility pattern for layer modules.

rust/jstprove_circuits/src/circuit_functions/layers/batchnorm.rs (1)

6-12: LGTM! The import path update correctly reflects the refactoring of linear algebra functions from the layers module to `gadgets::linear_algebra`. The function signatures remain unchanged, as confirmed by the relevant code snippets.

rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2)
22-23: LGTM! The imports for `SqueezeLayer` and `UnsqueezeLayer` are correctly added following the existing import pattern.

149-150: LGTM! The registry entries for `Squeeze` and `Unsqueeze` are correctly wired to their respective layer builders. The pattern matches existing layer registrations, and the builder signatures align with the `LayerOp` trait requirements as shown in the relevant code snippets.

python/tests/onnx_quantizer_tests/layers/squeeze_config.py (3)
1-14: LGTM! The imports are well-organized and include all necessary dependencies for the test configuration provider.

17-39: LGTM! The `SqueezeConfigProvider` class and `get_config` method are well-structured. The base configuration correctly represents the newer opset form, where axes is provided as an int64 initializer input.

41-114: LGTM! The test specifications provide comprehensive coverage:

- Valid tests cover axes-omitted semantics, basic initializer usage, singleton axis, and negative-axis handling
- Error tests verify rejection of duplicate axes and unsupported dynamic axes input
- E2E tests ensure full pipeline verification

The test case logic is correct (e.g., axis -2 on a rank-3 tensor resolves to axis 1).

python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py (3)
1-14: LGTM! The imports are consistent with the squeeze_config.py pattern and include all necessary dependencies.

17-41: LGTM! The `UnsqueezeConfigProvider` class correctly configures the newer opset form, where axes is provided as an int64 initializer. The output shape [1, 3, 1, 5] correctly reflects inserting dimensions at axes 0 and 2 into input shape [3, 5].
43-114: LGTM! The test specifications provide solid coverage for Unsqueeze:
- Valid tests cover various axis configurations including single, multiple, and negative axes
- Error tests correctly verify rejection of duplicate and dynamic axes
- The absence of an "axes_omitted" test is appropriate since Unsqueeze requires axes (unlike Squeeze which can default to removing all size-1 dims)
Output shape calculations are correct for all cases.
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (4)
1-25: LGTM! The imports and struct definition are well-organized. The `UnsqueezeLayer` struct correctly captures all necessary fields for the operation, and the `#[allow(dead_code)]` attribute is appropriate for fields used only in error messages.

27-61: LGTM! The `normalize_axes` method correctly handles:
- Negative axis conversion using output rank
- Range validation (0 ≤ axis < rank_out)
- Duplicate detection via `HashSet`
- Returning sorted axes for predictable shape construction
The error messages are descriptive and include context.
63-101: LGTM! The `unsqueezed_shape` method correctly implements ONNX Unsqueeze semantics:

- Computes output rank as `rank_in + axes.len()`
- Inserts 1s at the specified axes positions
- Validates all input dimensions are consumed
- Error handling covers edge cases like running out of input dims
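The same semantics can be sketched in Python. This is an illustrative re-implementation, not the Rust code under review:

```python
def normalize_axes(axes, rank_out):
    """Map possibly-negative axes into [0, rank_out), reject duplicates, sort."""
    normalized = []
    for ax in axes:
        ax = ax + rank_out if ax < 0 else ax
        if not 0 <= ax < rank_out:
            raise ValueError(f"axis {ax} out of range for output rank {rank_out}")
        normalized.append(ax)
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"duplicate axes after normalization: {normalized}")
    return sorted(normalized)

def unsqueezed_shape(in_shape, axes):
    """Insert size-1 dims at `axes`; output rank is rank_in + len(axes)."""
    rank_out = len(in_shape) + len(axes)
    axes_set = set(normalize_axes(axes, rank_out))
    dims = iter(in_shape)
    return [1 if i in axes_set else next(dims) for i in range(rank_out)]
```

For example, `unsqueezed_shape([3, 5], [0, 2])` yields `[1, 3, 1, 5]`, matching the test configuration discussed earlier.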
135-161: LGTM! The `build` method correctly extracts layer metadata and requires the `axes` parameter (via `get_param`, which errors if missing). This aligns with ONNX Unsqueeze semantics, where axes is mandatory. The pattern is consistent with other layer implementations shown in the relevant code snippets.

rust/jstprove_circuits/src/circuit_functions/layers/gemm.rs (3)
71-90: LGTM! Clean struct definition with well-documented fields. The `GemmLayer` struct is well-organized. The transition from `v_plus_one` to `source_scale_exponent` and the addition of `freivalds_reps` align with the PR objectives.

150-159: LGTM! Core product computation and bias addition are well-structured. The flow of computing the core product (with optional Freivalds verification) followed by constrained bias addition is correct. The separation of concerns between `compute_core_product` and bias handling is clean.

200-209: Good defensive validation for `freivalds_reps`. Build-time validation ensuring `freivalds_reps != 0` prevents accidentally disabling verification. This is correctly duplicated from the runtime check in `should_use_freivalds` for fail-fast behavior.

rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (5)
17-25: LGTM! Clean struct definition for SqueezeLayer. The struct captures all necessary fields for the Squeeze operation. The `axes: Option<Vec<i64>>` field correctly models ONNX semantics, where axes can be omitted.

27-65: LGTM! Robust axis normalization with proper validation. The `normalize_axes` method correctly handles:
- Negative axis indices (Python/ONNX convention)
- Out-of-range validation
- Duplicate detection
- Consistent sorted output
The error messages are informative and include context (axis value, rank, axes list).
67-106: LGTM! Correct ONNX Squeeze shape inference. The `squeezed_shape` implementation correctly handles both modes:
- Axes omitted → remove all dimensions of size 1
- Axes specified → validate and remove only those dimensions
The validation that specified axes must have `dim == 1` is correct per the ONNX spec.
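Both modes can be sketched in Python (an illustrative re-implementation, not the Rust code under review):

```python
def squeezed_shape(in_shape, axes=None):
    """ONNX Squeeze shape inference: drop the listed size-1 dims,
    or every size-1 dim when axes is omitted."""
    rank = len(in_shape)
    if axes is None:
        return [d for d in in_shape if d != 1]
    norm = []
    for ax in axes:
        ax = ax + rank if ax < 0 else ax  # negative axes count from the end
        if not 0 <= ax < rank:
            raise ValueError(f"axis {ax} out of range for rank {rank}")
        if in_shape[ax] != 1:
            raise ValueError(f"cannot squeeze axis {ax} with size {in_shape[ax]}")
        norm.append(ax)
    if len(set(norm)) != len(norm):
        raise ValueError(f"duplicate axes: {norm}")
    return [d for i, d in enumerate(in_shape) if i not in norm]
```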
109-138: LGTM! Correct passthrough implementation. The `apply` method correctly implements Squeeze as a pure shape transform:
- No circuit constraints are added (as expected for a view/reshape operation)
- Element order is preserved via flatten → reshape
- Error handling for invalid reshapes is present
154-160: LGTM! Correct handling of the optional axes parameter. The conditional check for `params.get(AXES)` before calling `get_param` correctly handles both:
- Axes omitted (older ONNX semantics or "remove all size-1 dims")
- Axes present as attribute/parameter
python/core/model_processing/onnx_quantizer/layers/squeeze.py (4)
21-28: LGTM! Clean mixin defining Squeeze characteristics. The `QuantizeSqueeze` mixin correctly specifies that Squeeze:

- Has no weights/biases (`USE_WB = False`)
- Requires no scaling (`USE_SCALING = False`)
- Only the data input (index 0) is relevant for scale planning
46-68: LGTM! Clean passthrough quantization. The `quantize` method correctly delegates to the base implementation, reflecting that Squeeze is a pure shape transform requiring no special quantization logic.

79-118: LGTM! Robust initializer parsing with comprehensive validation. The `_get_axes_from_initializer_input` method handles:
- Missing initializer (dynamic axes) → clear error
- Non-integer dtype → validation error
- 0-D scalar → converts to single-element list
- 1-D array → converts to list
- 2D+ arrays → rejected
Error messages include useful context (`node_name`, `op_type`, expected values).
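The parsing rules above amount to a small amount of numpy logic. A hedged sketch follows (hypothetical helper; the real method also reports `node_name` and `op_type` in its errors):

```python
import numpy as np

def axes_from_initializer(arr):
    """Accept a constant 0-D or 1-D integer array as Squeeze/Unsqueeze axes."""
    if arr is None:
        # No initializer means the axes input is dynamic, which is unsupported.
        raise ValueError("dynamic axes are unsupported; axes must be a constant initializer")
    if not np.issubdtype(arr.dtype, np.integer):
        raise ValueError(f"axes must be an integer tensor, got {arr.dtype}")
    if arr.ndim == 0:
        return [int(arr)]             # 0-D scalar -> single-element list
    if arr.ndim == 1:
        return [int(v) for v in arr]  # 1-D vector -> plain list
    raise ValueError(f"axes must be 0-D or 1-D, got {arr.ndim}-D")
```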
120-154: LGTM! Well-structured validation with appropriate deferrals. The `check_supported` method:
- Validates structural constraints (input count, axes source)
- Rejects dynamic axes (not in initializer_map)
- Validates axes uniqueness
- Correctly defers shape-dependent validation (e.g., "is axis actually squeezable?") to the Rust circuit layer
This separation of concerns is clean and matches the PR design.
rust/jstprove_circuits/src/circuit_functions/gadgets/linear_algebra.rs (6)
1-17: LGTM! Excellent module documentation. The module-level documentation clearly explains:
- The distinction between constrained and unconstrained gadgets
- The security implications of unconstrained operations
- The purpose of Freivalds verification
- Error handling via `LayerError`/`LayerKind`
47-77: LGTM! Correct dot product implementation. The `dot` function correctly validates that inputs are 1D vectors of equal length and computes the sum of element-wise products using constrained operations.

Note: the `#[allow(dead_code)]` annotation suggests this may be unused currently. Consider removing it if/when it's integrated into a layer.
223-269: LGTM! Clean factorization of elementwise operations. The `elementwise_op` helper effectively factors out common logic:
- Shape matching with automatic reshape when element counts match
- Consistent error handling across add/mul/sub
- Clean use of function references for the operation
The reshape-by-element-count behavior is documented and useful for bias addition scenarios.
294-338: LGTM! Standard constrained matrix multiplication. The implementation correctly:

- Validates that both inputs are 2D
- Checks inner dimension compatibility (`A.cols == B.rows`)
- Computes the product using constrained `mul` and `add` operations
344-403: LGTM! Well-documented unconstrained multiplication with a clear security warning. The documentation at lines 349-353 clearly warns that:

- This adds no constraints relating `C` to `A` and `B`
- The result is only a witness suggestion
- It must be linked back via `freivalds_verify_matrix_product` or full constraints

This is critical information for maintainers and is appropriately highlighted.
449-563: Correct Freivalds implementation; soundness depends on `get_random_value()` framework semantics. The implementation correctly follows the standard Freivalds algorithm: for each repetition (lines 515-560), it samples a random vector (line 519), computes B·x, C·x, and A·(B·x), then asserts equality. Matrix dimension validation and `num_repetitions` enforcement are thorough.
The critical dependency is that `api.get_random_value()` (line 519) provides cryptographically suitable, prover-unbiased randomness. In zero-knowledge contexts, if the prover could influence the randomness after seeing A, B, and C, the soundness bound would degrade. Confirm that the expander_compiler framework's RootAPI trait ensures randomness is either (1) derived from a random oracle independent of circuit inputs, or (2) generated from a verifier challenge committed before proof computation.
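Outside the circuit, the algorithm itself is easy to demonstrate with numpy. The sketch below uses small random nonzero integer vectors rather than field elements, purely to illustrate the check; it is not the circuit implementation.

```python
import numpy as np

def freivalds_check(a, b, c, reps=1, rng=None):
    """Probabilistically verify a @ b == c: each repetition compares
    a @ (b @ x) with c @ x for a random vector x, costing O(n^2)
    per repetition instead of the O(n^3) of recomputing the product."""
    rng = rng or np.random.default_rng(0)
    n = b.shape[1]
    for _ in range(reps):
        x = rng.integers(1, 10, size=n)  # random small nonzero entries
        if not np.array_equal(a @ (b @ x), c @ x):
            return False
    return True

a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
good = a @ b
bad = good + 1  # every entry wrong; any all-nonzero x detects this
```

In this toy setup the corrupted product is rejected deterministically because every entry is wrong and x has no zero components; in general, soundness is probabilistic and improves geometrically with the number of repetitions.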
Actionable comments posted: 0
🧹 Nitpick comments (1)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1)
17-25: Clarify the purpose of `input_shape` and the dead_code attribute. The `input_shape` field is populated during construction but never used in the `apply` method, which instead recomputes the shape from the actual input tensor. If `input_shape` serves no validation or debugging purpose, consider removing it to simplify the struct. Additionally, the `#[allow(dead_code)]` attribute suggests the compiler is flagging unused code; confirm whether all fields are intentionally retained.
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (2)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
`extract_params_and_expected_shape` (176-208), `get_input_name` (255-267), `get_param` (223-240)

rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)

`input` (105-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Check Formatting and Linting
- GitHub Check: End-to-End Testing (3.10)
- GitHub Check: End-to-End Testing (3.11)
- GitHub Check: End-to-End Testing (3.12)
🔇 Additional comments (5)
rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (5)
28-61: LGTM: Axis normalization is correct. The implementation properly handles negative indices, validates bounds, detects duplicates, and sorts the normalized axes. This aligns with ONNX Unsqueeze semantics.
63-101: LGTM: Shape computation is correct. The method correctly computes the output shape by inserting dimensions of size 1 at the specified axes while preserving the order of input dimensions. The validation that all input dimensions are consumed adds robustness.
105-133: LGTM: Passthrough reshape implementation is correct. The `apply` method correctly implements unsqueeze as a pure shape transformation: flattening the input and reshaping to the computed output shape while preserving element order. No gates are added, which aligns with the design goal of treating Unsqueeze as a passthrough layer.
1-15: LGTM: Imports are clean and appropriate. All imports are utilized and follow the standard patterns for this codebase.
135-161: The Rust build implementation is correct for opset ≥13 Unsqueeze operations. The Python converter properly normalizes axes into `params["axes"]` whether axes comes as an attribute or a constant initializer, handling both scalar and array forms. The `get_param()` call in the Rust code correctly receives this normalized parameter, ensuring consistent behavior across all ONNX opset versions.
Actionable comments posted: 0
🧹 Nitpick comments (2)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (2)
76-117: Consider using `Option<&[i64]>` for more idiomatic Rust. The squeeze logic is correct and properly validates that specified axes have dimension size 1 before squeezing. However, the parameter type `Option<&Vec<i64>>` is considered non-idiomatic in Rust.

🔎 Proposed refactor

Change the signature to use `Option<&[i64]>`:

```diff
 fn squeezed_shape(
     &self,
     input_shape: &[usize],
-    axes: Option<&Vec<i64>>,
+    axes: Option<&[i64]>,
 ) -> Result<Vec<usize>, CircuitError> {
```

Then at line 91, simplify:

```diff
-let axes_ref = axes.expect("axes is Some here");
-let axes_u = self.normalize_axes(axes_ref.as_slice(), rank)?;
+let axes_slice = axes.unwrap(); // Safe: already checked is_none() above
+let axes_u = self.normalize_axes(axes_slice, rank)?;
```

And at the call site (line 136), change to:

```diff
-let out_shape = self.squeezed_shape(&in_shape, self.axes.as_ref())?;
+let out_shape = self.squeezed_shape(&in_shape, self.axes.as_deref())?;
```
121-149: Remove the redundant clone on line 128. The apply method correctly implements a pure shape transformation with no circuit constraints, matching the PR objectives. However, there's a minor inefficiency.

🔎 Proposed fix

```diff
 let layer_input = input
-    .get(&input_name.clone())
+    .get(input_name)
     .ok_or_else(|| LayerError::MissingInput {
```

The `get()` method can accept `&String` directly without cloning. This matches the pattern in unsqueeze.rs line 125.
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (1 hunk)
- rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (1)
- rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs
🧰 Additional context used
🧬 Code graph analysis (1)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (3)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
`extract_params_and_expected_shape` (176-208), `get_input_name` (255-267), `get_param` (223-240)

rust/jstprove_circuits/src/circuit_functions/layers/unsqueeze.rs (3)

`normalize_axes` (28-75), `apply` (119-147), `build` (149-175)

rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)

`input` (105-105)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Check Formatting and Linting
- GitHub Check: End-to-End Testing (3.11)
- GitHub Check: End-to-End Testing (3.12)
- GitHub Check: End-to-End Testing (3.10)
🔇 Additional comments (4)
rust/jstprove_circuits/src/circuit_functions/layers/squeeze.rs (4)
1-15: LGTM! The imports and constant definition are appropriate and follow the established patterns in other layer implementations.

17-25: LGTM! The struct definition correctly models Squeeze semantics with optional axes (None → squeeze all size-1 dimensions, Some → squeeze the specified dimensions).

28-74: LGTM! The axis normalization logic correctly handles negative indices, validates bounds, detects duplicates, and provides clear error messages. The implementation follows the same pattern as UnsqueezeLayer.

151-182: LGTM! The build method correctly handles the optional axes parameter, which is essential for ONNX Squeeze semantics (when axes are omitted, all size-1 dimensions are removed). The conditional extraction pattern at lines 168-171 is appropriate.
Actionable comments posted: 0
🧹 Nitpick comments (1)
python/core/model_processing/onnx_quantizer/layers/squeeze.py (1)
46-52: Consider always initializing `new_initializers` for clarity. The conditional initialization on lines 51-52 only sets `self.new_initializers` when the parameter is not `None`. If external code accesses this attribute without using `hasattr()`, this could raise `AttributeError`. For clarity and defensive programming, consider always initializing:

```python
self.new_initializers = new_initializers if new_initializers is not None else []
```

This ensures the attribute always exists with a predictable type.
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- python/core/model_processing/onnx_quantizer/layers/squeeze.py (1 hunk)
- python/tests/onnx_quantizer_tests/layers/squeeze_config.py (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (1)
- python/tests/onnx_quantizer_tests/layers/squeeze_config.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: End-to-End Testing (3.12)
- GitHub Check: End-to-End Testing (3.11)
- GitHub Check: End-to-End Testing (3.10)
- GitHub Check: Check Formatting and Linting
🔇 Additional comments (7)
python/core/model_processing/onnx_quantizer/layers/squeeze.py (7)
1-19: LGTM: Clean imports and proper TYPE_CHECKING usage. The imports are well-organized, and using `TYPE_CHECKING` for the `onnx` import avoids runtime circular dependencies.
21-28: LGTM: Correct passthrough configuration. The configuration correctly marks Squeeze as a scale-preserving operation with no arithmetic or rescaling. The `SCALE_PLAN` appropriately identifies only the data input (index 0) as relevant for scale planning.
54-68: LGTM: Clean delegation to the base quantizer. The method correctly delegates to the base `QuantizeSqueeze.quantize` for passthrough behavior, maintaining consistency with the design philosophy for shape-only operations.
70-71: LGTM: Clear constant definitions. The constants clearly distinguish between the two input configurations for Squeeze, improving code readability.
73-77: LGTM: Straightforward attribute extraction.

The method correctly extracts the `axes` attribute when present and returns `None` for the optional parameter case.
79-118: LGTM: Robust axes extraction with comprehensive validation.

The method correctly:
- Rejects dynamic (runtime-variable) axes as required by the circuit backend
- Validates integer dtype and proper dimensionality (0-D or 1-D)
- Provides clear, actionable error messages
The explicit `int()` conversion on line 110 is technically redundant, since `arr.tolist()` already returns Python ints for integer arrays, but it's defensive and ensures type consistency, so it's acceptable.
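A quick check of the `tolist()` behavior referenced above, assuming a NumPy integer array as in the quantizer:

```python
import numpy as np

arr = np.array([0, 2], dtype=np.int64)
axes = arr.tolist()
# tolist() already yields native Python ints for integer arrays,
# so the extra int() call is defensive rather than required.
print(axes, all(type(a) is int for a in axes))  # → [0, 2] True
```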
120-155: LGTM: Comprehensive validation logic.

The validation correctly handles all supported Squeeze configurations:
- Axes via attribute (older opsets)
- Axes via constant initializer (newer opsets)
- Omitted axes (ONNX default: squeeze all size-1 dimensions)
The duplicate axes check is properly placed after the early return, and the comment on lines 142-143 clearly explains why shape validation is deferred to the Rust backend. The error messages throughout are clear and actionable.
jsgold-1
left a comment
Looks good overall! Just a couple points. I think we need to move some of the code around a bit, but overall looks okay. Happy to chat about where we can move things to
…Max broadcast e2e test
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@python/core/model_processing/converters/onnx_converter.py`:
- Around line 74-135: _extract_unsqueeze_axes_into_params currently treats any
non-initializer axes input as dynamic and raises LayerAnalysisError; change it
to accept Constant-node axes by adding a new parameter model_graph_nodes (list
of NodeProto), and after building initializer_map if axes_name not in it, search
model_graph_nodes for a node with op_type=="Constant" and output name ==
axes_name, extract its "value" attribute (value_attr.t) and convert to numpy
with numpy_helper.to_array(value_attr.t) to populate axes_arr; keep the existing
integer-dtype check and axes normalization logic, and update callers (e.g.,
_pre_transform_unsqueeze) to pass model.graph.node as model_graph_nodes when
calling _extract_unsqueeze_axes_into_params.
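A sketch of the resolution order the instruction describes, using plain dicts as stand-ins for `NodeProto`; with real onnx objects the Constant payload would come from `numpy_helper.to_array(value_attr.t)`:

```python
import numpy as np

def resolve_axes(axes_name, initializer_map, graph_nodes):
    """Resolve the axes input of Squeeze/Unsqueeze to a static array.

    Plain dicts stand in for onnx NodeProto here. Returns None for
    truly dynamic axes, which the converter rejects.
    """
    if axes_name in initializer_map:      # 1) graph initializer
        return initializer_map[axes_name]
    for node in graph_nodes:              # 2) Constant node feeding the input
        if node["op_type"] == "Constant" and axes_name in node["output"]:
            return node["value"]
    return None                           # 3) dynamic -> unsupported

nodes = [{"op_type": "Constant", "output": ["axes_0"],
          "value": np.array([0], dtype=np.int64)}]
print(resolve_axes("axes_0", {}, nodes))  # → [0]
```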
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@python/core/model_processing/converters/onnx_converter.py`:
- Around line 664-667: Add an optional model parameter to analyze_layer (e.g.,
def analyze_layer(..., model: Optional[onnx.ModelProto] = None) and use model or
self.model inside it; update get_model_architecture to pass model=model when
calling analyze_layer. Change the pre-analysis transform invocation in
CONVERTER_PRE_ANALYSIS_TRANSFORMS usage from transform(node, params, self.model,
self.model_type) to transform(node, params, model or self.model,
self.model_type) so _pre_transform_unsqueeze and similar transforms extract
Constant/initializer axes from the analyzed graph. Ensure all internal
references in analyze_layer that previously used self.model now use the passed
model when present for backward compatibility.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@python/core/model_processing/onnx_quantizer/layers/unsqueeze.py`:
- Around line 212-220: Change the signature of
_extract_unsqueeze_axes_into_params to use graph: onnx.GraphProto instead of
onnx.ModelProto and remove the unused '# noqa: PLR0913' comment; update the
function definition for _extract_unsqueeze_axes_into_params (and any callers if
necessary) so the parameter type correctly reflects that the code accesses
graph.initializer and graph.node, and delete the redundant noqa directive from
that line.
jsgold-1
left a comment
LGTM
Description
This PR adds full support for the ONNX `Squeeze` and `Unsqueeze` operators across the quantization, model-conversion, and circuit-building pipeline. Both operators are treated as pure shape-transform / passthrough layers.
Key changes
Quantizer support (Squeeze & Unsqueeze)
Added `SqueezeQuantizer` and `UnsqueezeQuantizer` as strict passthroughs.

Circuit layer wiring (Rust)
ONNXConverter: explicit Unsqueeze axis handling
Additional validation logic was added to `python/core/model_processing/converters/onnx_converter.py` to correctly handle opset ≥ 13 Unsqueeze semantics. In opset ≥ 13, `axes` is provided as a second input, not an attribute.

The converter now:
- Detects `Unsqueeze` nodes where `axes` is not present in attributes
- Enforces that the `axes` input is a constant initializer
- Extracts the axes values from the initializer and injects them into `layer.params["axes"]`, so downstream circuit construction receives a uniform representation.
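The Unsqueeze shape rule the circuit relies on — negative axes are normalized against the output rank, per the ONNX spec — can be sketched as follows (`unsqueeze_shape` is illustrative, not part of the PR):

```python
def unsqueeze_shape(shape: list[int], axes: list[int]) -> list[int]:
    """Output shape of ONNX Unsqueeze (illustrative only)."""
    out_rank = len(shape) + len(axes)
    # Negative axes are normalized against the *output* rank.
    norm = sorted(a % out_rank for a in axes)
    if len(norm) != len(set(norm)):
        raise ValueError("duplicate axes after normalization")
    out = list(shape)
    for a in norm:
        out.insert(a, 1)
    return out

print(unsqueeze_shape([3, 4], [0, 3]))  # → [1, 3, 4, 1]
```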
This mirrors the existing handling of `Squeeze`.

End-to-end test coverage
Added E2E tests covering:
- `Squeeze`
- `Unsqueeze` with axes provided via initializer

Verified full flow: quantize → compile → witness → prove → verify
Related Issue
Type of Change
Checklist
Deployment Notes
Additional Comments
`Squeeze` and `Unsqueeze` are intentionally treated as no-op arithmetic layers; all logic is confined to shape metadata handling. Dynamic axes inputs to `Unsqueeze` are explicitly rejected to avoid unsound circuit construction.
New Features
Bug Fixes / Validation
Tests
✏️ Tip: You can customize this high-level summary in your review settings.