Add support for constant-divisor division #90
Walkthrough
Adds Div support across the codebase: a Python ONNX DivQuantizer (casts initializers to int64 and emits Div nodes), new Div test configs, a Rust DivLayer implementing fixed-point, range-checked elementwise division, layer registration, and a small build-config change.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In @python/core/model_processing/onnx_quantizer/layers/div.py:
- Around lines 25-31: The DivQuantizer constructor leaves self.new_initializers undefined when new_initializers is None, causing an AttributeError later where self.new_initializers.append(...) is called. Fix by always initializing self.new_initializers in __init__ (e.g., set it to new_initializers or to an empty list) so that any subsequent append operations in DivQuantizer methods are safe.
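A minimal sketch of the suggested fix (the class body and the _record_initializer helper are illustrative assumptions; the real DivQuantizer derives from the project's quantizer base):

```python
class DivQuantizer:
    """Illustrative stub of the constructor fix; not the real class hierarchy."""

    def __init__(self, new_initializers=None):
        # Always bind the attribute so later .append() calls cannot hit None.
        self.new_initializers = (
            new_initializers if new_initializers is not None else []
        )

    def _record_initializer(self, tensor):
        # Hypothetical helper standing in for the quantizer's append site.
        self.new_initializers.append(tensor)
```

Using `if ... is not None` rather than `new_initializers or []` keeps a caller-supplied empty list shared by reference, which matters if the registry passes one list to several quantizers.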
- Around lines 44-51: Add explicit type validation before the astype(np.int64) cast. After converting the initializer to a numpy array via numpy_helper.to_array(tensor), check arr.dtype (or tensor.data_type); if it is a floating dtype, raise a clear TypeError (or log an error) naming the initializer and describing the unexpected float initializer for a division operand. Only perform the astype(np.int64) cast and append the new initializer when the dtype is already integer. Apply the same validation to the analogous cast in batchnorm.py (around its existing astype calls).
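A hedged sketch of the dtype guard (cast_div_initializer is a hypothetical helper; the real code operates on ONNX initializers obtained via numpy_helper.to_array):

```python
import numpy as np


def cast_div_initializer(arr: np.ndarray, name: str) -> np.ndarray:
    """Validate that a division operand is integer-typed before the int64 cast."""
    if np.issubdtype(arr.dtype, np.floating):
        raise TypeError(
            f"Unexpected float initializer {name!r} (dtype {arr.dtype}) for a "
            "division operand; expected an integer dtype."
        )
    return arr.astype(np.int64)
```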
In @rust/jstprove_circuits/src/circuit_functions/layers/div.rs:
- Around lines 86-88: The closure mapping divisor values performs an unchecked cast (`divisor_val as u32`) that truncates values greater than u32::MAX. Update the closure (used where you map over `(dividend, divisor_val)`) to bounds-check divisor_val (ensure 0 <= divisor_val <= u32::MAX) and return a Result (e.g., Err for out-of-range values) instead of silently casting. Propagate errors via `try_collect()` or explicit iteration and error handling in the surrounding function so that invalid large divisors produce a clear error rather than incorrect wrapped values.
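The checked conversion can be sketched in standalone Rust (function names are illustrative, not the actual div.rs API):

```rust
// Reject divisors outside [0, u32::MAX] instead of silently truncating with `as`.
fn checked_divisor(divisor_val: i64) -> Result<u32, String> {
    u32::try_from(divisor_val)
        .map_err(|_| format!("divisor {divisor_val} out of range for u32"))
}

// collect() into a Result short-circuits on the first out-of-range divisor,
// mirroring the try_collect()/explicit-propagation suggestion above.
fn checked_divisors(vals: &[i64]) -> Result<Vec<u32>, String> {
    vals.iter().map(|&v| checked_divisor(v)).collect()
}
```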
- Around lines 164-165: Creating initializer_a and initializer_b with direct indexing (layer.inputs[0] and layer.inputs[1]) can panic. Replace those accesses with the safe helper get_input_name (as used in apply) to obtain the input names, then pass them into get_optional_w_or_b (e.g., call get_input_name(layer, 0)? and get_input_name(layer, 1)? first). Errors are then propagated instead of panicking, while the existing get_optional_w_or_b, layer_context, and layer-inputs logic is preserved.
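The panic-free lookup pattern, as a generic sketch (the project's get_input_name helper already encapsulates this; only the .get()/ok_or shape is shown):

```rust
// Return an error instead of panicking when a layer input index is missing.
fn input_name(inputs: &[String], idx: usize) -> Result<&str, String> {
    inputs
        .get(idx)
        .map(String::as_str)
        .ok_or_else(|| format!("layer is missing input #{idx}"))
}
```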
- Around lines 105-107: The calls to logup_ctx.range_check::<C, Builder>(api, remainder_bound, divisor_bits) (and the two similar calls later) use .expect(), which panics. Change the surrounding closures to return Result<..., E> and propagate range_check errors instead of unwrapping: replace .expect(...) with the ? operator (or map_err + ?) and refactor the iterator chain with try_fold, or collect into a Result<Vec<_>, E>, so failures return Err up the call chain. Update the enclosing closure's signature and its caller to return Result and propagate errors accordingly.
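The shape of that refactor, with a stand-in range_check (the real call goes through logup_ctx and the circuit builder, which are not reproduced here):

```rust
// Stand-in for the circuit's range check: Err instead of panic on failure.
fn range_check(value: u64, bits: u32) -> Result<(), String> {
    if bits >= 64 || value < (1u64 << bits) {
        Ok(())
    } else {
        Err(format!("{value} does not fit in {bits} bits"))
    }
}

// try_for_each propagates the first Err up the call chain, replacing .expect().
fn check_all(values: &[u64], bits: u32) -> Result<(), String> {
    values.iter().try_for_each(|&v| range_check(v, bits))
}
```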
🧹 Nitpick comments (4)
pyproject.toml (1)
27-34: Consider aligning pytest version constraints between the dev and test groups. The dev group specifies pytest>=8.3.5 while test pins pytest==8.3.5. This inconsistency could lead to different pytest versions in development vs. test environments, potentially causing subtle behavior differences. Consider either pinning both or using the same constraint style for consistency.
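For example, the two groups could share one constraint style (a sketch; the actual group layout in pyproject.toml may differ):

```toml
[dependency-groups]
dev = ["pytest>=8.3.5"]
test = ["pytest>=8.3.5"]  # same style as dev, or pin both with ==
```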
python/core/model_processing/onnx_quantizer/layers/div.py (1)
62-67: Consider adding validation in check_supported. The check_supported method is a no-op. Other quantizers may validate specific constraints (e.g., supported attributes, input counts). If Div has any limitations (e.g., divisor must be non-zero, specific tensor types), consider adding validation here.

#!/bin/bash
# Check how other quantizers implement check_supported
rg -A 10 "def check_supported" --type py python/core/model_processing/onnx_quantizer/layers/

python/tests/onnx_quantizer_tests/layers/div_config.py (1)
41-91: Consider adding error tests for invalid divisors. The Rust implementation explicitly validates that divisors must be positive and rejects non-positive values. Consider adding error test specs to verify this validation behavior, such as testing division by zero or negative divisors.
Additionally, note that the e2e tests (lines 80-91) call the RNG after the valid tests, so they will get different random divisor values than initializer_div and broadcast_div. If identical divisor values are intended for corresponding valid and e2e tests, consider storing the generated arrays in variables before building the specs.
23-34: Unused n_bits field. The n_bits field is stored (line 175) but never used in the apply method. If this is intentional for future use or interface consistency, a brief comment would clarify the intent. Otherwise, consider removing it.
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
Cargo.lock is excluded by !**/*.lock
uv.lock is excluded by !**/*.lock
📒 Files selected for processing (7)
pyproject.toml
python/core/model_processing/onnx_quantizer/layers/div.py
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py
python/tests/onnx_quantizer_tests/layers/div_config.py
rust/jstprove_circuits/src/circuit_functions/layers/div.rs
rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs
rust/jstprove_circuits/src/circuit_functions/layers/mod.rs
🧰 Additional context used
🧬 Code graph analysis (4)
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (1)
python/core/model_processing/onnx_quantizer/layers/div.py (1)
DivQuantizer(17-67)
rust/jstprove_circuits/src/circuit_functions/layers/div.rs (4)
rust/jstprove_circuits/src/circuit_functions/gadgets/range_check.rs (1)
range_check (233-256)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (3)
get_optional_w_or_b (131-140), extract_params_and_expected_shape (176-208), get_input_name (255-267)
rust/jstprove_circuits/src/circuit_functions/utils/tensor_ops.rs (1)
load_array_constants_or_get_inputs (62-81)
rust/jstprove_circuits/src/circuit_functions/utils/shaping.rs (1)
input (105-105)
rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (1)
rust/jstprove_circuits/src/circuit_functions/layers/div.rs (1)
build(150-178)
python/tests/onnx_quantizer_tests/layers/div_config.py (1)
python/tests/onnx_quantizer_tests/layers/base.py (8)
LayerTestConfig (58-166), LayerTestSpec (27-55), e2e_test (247-248), edge_case_test (243-244), valid_test (235-236), description (175-177), override_initializer (187-189), tags (215-217)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: End-to-End Testing (3.11)
- GitHub Check: End-to-End Testing (3.10)
- GitHub Check: End-to-End Testing (3.12)
- GitHub Check: Check Formatting and Linting
🔇 Additional comments (11)
rust/jstprove_circuits/src/circuit_functions/layers/mod.rs (1)
6-6: LGTM! The new div module is correctly declared as public and follows the existing alphabetical ordering convention.
rust/jstprove_circuits/src/circuit_functions/layers/layer_kinds.rs (2)
3-3: LGTM! The DivLayer import follows the established pattern for layer imports.
137-137: LGTM! The Div layer is correctly registered in the layer registry, following the same pattern as other arithmetic operators (Add, Sub, Mul). The builder dispatch routes to DivLayer::build as expected.
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2)
27-28: Minor: import ordering. The DivQuantizer import is placed after ConstantQuantizer and before ConvQuantizer, which already satisfies alphabetical ordering; keep it there unless the imports are regrouped.
94-94: LGTM! The DivQuantizer is correctly registered for the "Div" operation type and receives self.new_initializers, consistent with other arithmetic quantizers like AddQuantizer, SubQuantizer, and MulQuantizer.
python/tests/onnx_quantizer_tests/layers/div_config.py (3)
1-11: LGTM! Imports are clean and align with the base test infrastructure patterns shown in the relevant snippets.
14-25: Good documentation of the quantization rationale. The docstring effectively explains why the tests focus on constant divisors, which provides valuable context for maintainers.
27-39: LGTM! The configuration correctly sets up B as an initializer (constant divisor) rather than a dynamic input, matching the quantization rationale in the docstring.
rust/jstprove_circuits/src/circuit_functions/layers/div.rs (3)
1-21: LGTM! Imports and constants are appropriate for the fixed-point division implementation.
52-81: LGTM! Input validation correctly enforces positive divisors. The broadcast operation and error handling are well-structured.
167-178: LGTM! The struct construction and return are correct. The layer properly captures all necessary state for the division operation.
Actionable comments posted: 0
🧹 Nitpick comments (2)
.github/workflows/security.yml (1)
28-32: Toolchain version update looks reasonable. The nightly version bump from 2025-03-27 to 2025-07-15 aligns with the new Rust code (DivLayer) introduced in this PR.
Note: the actions-rs/toolchain@v1 action is archived/deprecated. Consider migrating to dtolnay/rust-toolchain in a follow-up for better long-term maintenance.
rust/jstprove_circuits/src/circuit_functions/layers/div.rs (1)
165-175: Consider validating the divisor initializer at build time. The divisor (initializer_b) is required for this layer to function, but validation only occurs at runtime in apply(). Validating during build() would provide earlier feedback and clearer error messages at circuit construction time rather than during execution.
♻️ Proposed refactor for early validation
let dividend_name = get_input_name(&layer.inputs, 0, LayerKind::Div, INPUT)?;
let divisor_name = get_input_name(&layer.inputs, 1, LayerKind::Div, INPUT)?;
let initializer_a = get_optional_w_or_b(layer_context, dividend_name)?;
let initializer_b = get_optional_w_or_b(layer_context, divisor_name)?;
+ if initializer_b.is_none() {
+     return Err(LayerError::Other {
+         layer: LayerKind::Div,
+         msg: "Div requires constant divisor (initializer_b). Dynamic divisors not supported.".to_string(),
+     }.into());
+ }
+
Ok(Box::new(Self {
    inputs: layer.inputs.clone(),
    outputs: layer.outputs.clone(),
    initializer_a,
    initializer_b,
}))
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
.github/workflows/security.yml
python/core/model_processing/onnx_quantizer/layers/div.py
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py
python/tests/onnx_quantizer_tests/layers/div_config.py
rust/jstprove_circuits/src/circuit_functions/layers/div.rs
🚧 Files skipped from review as they are similar to previous changes (2)
- python/core/model_processing/onnx_quantizer/layers/div.py
- python/tests/onnx_quantizer_tests/layers/div_config.py
🧰 Additional context used
🧬 Code graph analysis (2)
rust/jstprove_circuits/src/circuit_functions/layers/div.rs (3)
rust/jstprove_circuits/src/circuit_functions/gadgets/range_check.rs (2)
range_check (233-256), new_default (208-210)
rust/jstprove_circuits/src/circuit_functions/utils/onnx_model.rs (2)
get_optional_w_or_b (131-140), get_input_name (255-267)
rust/jstprove_circuits/src/circuit_functions/utils/tensor_ops.rs (1)
load_array_constants_or_get_inputs (62-81)
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (1)
python/core/model_processing/onnx_quantizer/layers/div.py (1)
DivQuantizer(17-67)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: End-to-End Testing (3.11)
- GitHub Check: End-to-End Testing (3.10)
- GitHub Check: End-to-End Testing (3.12)
- GitHub Check: Check Formatting and Linting
- GitHub Check: Rust Security Audit
🔇 Additional comments (8)
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py (2)
28-28: LGTM! The import follows the established pattern and is correctly placed in alphabetical order among the other layer quantizer imports.
94-94: LGTM! The Div handler registration correctly follows the established pattern used by other arithmetic operators (Add, Sub, Mul) that require shared access to new_initializers for creating cast tensors.
rust/jstprove_circuits/src/circuit_functions/layers/div.rs (6)
1-17: LGTM! Imports and the SHIFT_BITS constant are appropriate for the fixed-point division implementation.
19-25: LGTM! The struct design is consistent with other layer implementations and properly captures the optional initializers.
43-76: LGTM! The divisor validation correctly enforces the [1, 2^32) range, and the broadcasting with proper error handling is well implemented.
101-109: LGTM! The remainder bound check correctly validates 0 <= remainder < divisor using the appropriate bit width derived from the divisor value.
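The constrained arithmetic can be checked outside the circuit: for dividend d and constant divisor v, the prover supplies quotient q and remainder r with d = q·v + r and 0 <= r < v. A sketch of the witness computation (illustrative, not the circuit code):

```rust
// Euclidean division keeps the remainder in [0, v) even for negative dividends,
// matching the 0 <= remainder < divisor range check.
fn div_witness(d: i64, v: i64) -> (i64, i64) {
    (d.div_euclid(v), d.rem_euclid(v))
}
```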
140-154: LGTM! The LogUp context lifecycle is properly managed (init → range checks → finalize), and the output tensor construction with error handling is correctly implemented.
127-138: Document the dividend magnitude constraint imposed by the 24-bit range checks. The range checks on lines 128 and 134 constrain the dividend magnitude to fit within ±2^24. While mathematically sound, this constraint lacks documentation explaining whether the quantization scheme guarantees dividends remain within these bounds. If quantized intermediate values can exceed ±2^24 in magnitude, the proof will silently fail due to the range-check constraint.
Add a comment documenting the expected dividend range given the quantization scheme, or parameterize the bit width based on the quantization exponent to avoid future issues if the scheme changes.
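One way to parameterize the width, sketched as a plain helper (deriving the dividend bound from the quantization exponent is an assumption about the scheme, not existing div.rs code):

```rust
// Bits needed to represent any value in [0, max_abs]; the circuit could use
// this instead of a hard-coded 24-bit magnitude bound.
fn required_bits(max_abs: u64) -> u32 {
    64 - max_abs.leading_zeros()
}
```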
Summary by CodeRabbit
New Features
Tests
Chores