Feature/quant compare mode #4
Conversation
…wer flip-rate analysis

- Introduce the `compare` CLI command and runner method to evaluate two arbitrary models or quantizations across different backends.
- Implement category-aware answer extraction (numeric, boolean, code, JSON) to calculate semantic flip rates, comparing functional answers instead of relying solely on text similarity.
- Add model resolution logic to support explicit backend prefixes (e.g., `ollama:`, `mlx:`, `gguf:`) and HuggingFace repo heuristics.
- Introduce the `CompareResult` data model for capturing and serializing evaluation metrics.
- Add comprehensive unit tests for answer extraction, model resolution, CLI, and runner logic.
…ross backends. Add `distributions` field to `InferenceResult` and update `openai_compat`, `mlx_lm`, and `llama_cpp` backends to parse and populate token probability distributions. Implement KL divergence computation in the runner's comparison logic, and include `mean_kl_divergence` statistics in both sweep and compare result summaries.
…on heuristics.

- Introduce `get_backend_for_model` to encapsulate model resolution and backend instantiation.
- Update `sweep`, `stress`, and `determinism` CLI commands to auto-detect the backend if not explicitly provided.
- Improve HuggingFace repository heuristics in `resolve.py` to better identify MLX and GGUF models, changing the default GGUF backend from `openai-compat` to `llama-cpp`.
- Fix a minor type checking issue in the `mlx_lm` backend adapter.
- Add and update unit tests for the new model resolution heuristics.
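Prefix-plus-heuristic resolution of this kind can be sketched as below. The function name `resolve_backend` and the exact heuristics are assumptions for illustration; the PR's `resolve.py` likely differs in detail:

```python
def resolve_backend(spec: str) -> tuple[str, str]:
    """Map a model spec string to (backend, model_id). Illustrative heuristics only."""
    # Explicit prefixes always win.
    prefixes = {"ollama:": "ollama", "mlx:": "mlx-lm", "gguf:": "llama-cpp"}
    for prefix, backend in prefixes.items():
        if spec.startswith(prefix):
            return backend, spec[len(prefix):]
    lowered = spec.lower()
    # HuggingFace repo heuristics: org and naming conventions hint at the format.
    if lowered.startswith("mlx-community/") or lowered.endswith("-mlx"):
        return "mlx-lm", spec
    if "gguf" in lowered:
        return "llama-cpp", spec
    # Fall back to a generic OpenAI-compatible server.
    return "openai-compat", spec
```

The explicit-prefix check runs first so users can always override the heuristics.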
- Replace the placeholder in the CLI with an actual call to `generate_report` for the `compare` command. - Add support for loading `CompareResult` instances in the HTML report generator.
Introduce `quant-sensitive.jsonl` with 20 new prompts targeting multi-digit arithmetic, long chain-of-thought, and precise syntax. Update the README to document the new suite in the bundled suites list and command examples.
Pull request overview
This PR adds a new “compare” workflow for head-to-head model/quant comparisons (including backend auto-detection), plus supporting analysis and prompt-suite updates to better measure quantization sensitivity.
Changes:
- Introduces a `compare` CLI command, `TestRunner.compare()`, and `CompareResult` to report flip-rate, similarity, (attempted) KL divergence, and per-category stats.
- Adds category-aware answer extraction (`analysis.answer_extract`) and new unit tests for resolver/compare/answer-extraction.
- Adds a bundled `quant-sensitive.jsonl` suite and updates docs; removes the repo-level `prompt-suites/*.jsonl` contents (migration toward packaged suites).
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `src/infer_check/cli.py` | Adds `compare` command; switches other commands to backend auto-detection; prints compare summaries |
| `src/infer_check/runner.py` | Adds `TestRunner.compare()` and KL computation in `_compare()`; writes compare checkpoints |
| `src/infer_check/types.py` | Adds `CompareResult`; extends `InferenceResult` with `distributions` |
| `src/infer_check/resolve.py` | New model-spec resolver for backend + URL inference |
| `src/infer_check/backends/base.py` | Adds `get_backend_for_model()` (auto-resolves backend if not provided) |
| `src/infer_check/backends/mlx_lm.py` | Captures per-token distributions/logprobs from MLX generate-step; changes batch generation to sequential |
| `src/infer_check/backends/openai_compat.py` | Captures `top_logprobs` into `distributions` |
| `src/infer_check/backends/llama_cpp.py` | Requests `n_probs`; captures per-token probability lists into `distributions` |
| `src/infer_check/analysis/answer_extract.py` | New answer extraction + flip-rate logic |
| `src/infer_check/analysis/__init__.py` | Re-exports the answer-extraction API |
| `src/infer_check/reporting/html.py` | Loads `CompareResult` files into the report pipeline |
| `src/infer_check/prompt_suites/quant-sensitive.jsonl` | Adds new bundled quant-sensitive suite |
| `README.md` | Documents the new bundled suite |
| `tests/unit/test_compare.py` | Adds CLI + runner tests for compare (includes an unfinished stub test) |
| `tests/unit/test_resolve.py` | Adds resolver tests |
| `tests/unit/test_answer_extract.py` | Adds answer-extraction tests |
| `.pre-commit-config.yaml` | Adds `mlx-lm` to mypy hook env |
| `prompt-suites/*.jsonl` | Removes repo-level suite contents (migration/deprecation) |
Comments suppressed due to low confidence (1)
tests/unit/test_compare.py:343
`test_checkpoint_written` is currently a stub (only a docstring + commented-out code) and doesn't assert anything about checkpoint creation. Either implement the test (e.g., run `compare()` and assert a `compare_*_vs_*_*.json` file appears under `cache_dir`) or remove it to avoid dead/unfinished tests.
def test_checkpoint_written(self, test_runner: TestRunner, tmp_path: Path) -> None:
"""A checkpoint file is written to cache_dir."""
# prompts = _make_prompts(1)
# resp = {prompts[0].id: "ok"}
…rgence calculation. Add `distribution_metadata` to `InferenceResult` to store token alignment data. Update the MLX, llama.cpp, and OpenAI-compatible backends to extract and populate this metadata. Refactor the `_compare` method in the runner to align varying probability distributions (e.g., full vocab vs top-K) before computing KL divergence. Additionally, fix an incorrect prompt suite path in the test loader.
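Union-by-id alignment before KL, as this commit describes, can be sketched minimally as follows. This assumes each distribution is a dict from token id to log-probability; the dict representation, the function name, and the floor value are assumptions, not the runner's actual code:

```python
import math


def kl_divergence(dist_a: dict[int, float], dist_b: dict[int, float]) -> float:
    """Approximate KL(A || B) between two top-K log-prob maps keyed by token id.

    Token ids missing from one side are floored at a tiny probability so the
    union of ids from both distributions can be compared. With truncated top-K
    inputs the result is an approximation, not an exact KL.
    """
    floor = math.log(1e-10)
    token_ids = set(dist_a) | set(dist_b)
    kl = 0.0
    for tid in token_ids:
        log_p = dist_a.get(tid, floor)
        log_q = dist_b.get(tid, floor)
        # Contribution of this token: p * (log p - log q).
        kl += math.exp(log_p) * (log_p - log_q)
    return kl
```

Keying by token id is what makes a top-K distribution from one backend comparable to a full-vocab (or differently sorted top-K) distribution from another.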
fix: convert linear probabilities to log-probabilities in llama.cpp backend. The backend now applies `math.log` with a `1e-10` epsilon to the probabilities returned by `llama-server`, ensuring the data format matches the log-space expectations of the rest of the codebase (e.g., for KL divergence calculations).
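The conversion this fix describes is tiny; a sketch of the idea (the helper name `to_log_probs` is an assumption, but the epsilon matches the commit message):

```python
import math


def to_log_probs(probs: list[float], eps: float = 1e-10) -> list[float]:
    """Convert linear probabilities to log space, guarding log(0) with an epsilon."""
    return [math.log(p + eps) for p in probs]
```

Without the epsilon, a zero probability from `llama-server` would raise `ValueError: math domain error`.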
…ivergence calculation for OpenAI-compatible backends.
…ion fallback

- Sanitize model labels in `TestRunner.compare` to prevent path traversal and filesystem errors when creating checkpoint files.
- Update `answers_match` to fall back to raw sequence similarity when extraction fails, rather than a strict boolean check.
- Fix the `compare` CLI output to only print "No answer flips detected" when there are genuinely no flipped prompts.
- Update `resolve.py` docstrings and CLI help text to reflect current behavior and broader command usage.
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 8 comments.
    if backend_type:
        config = BackendConfig(
            backend_type=backend_type,  # type: ignore
            model_id=model_str,
            base_url=base_url,
            quantization=quantization,
        )
    else:
        resolved = resolve_model(model_str, base_url=base_url)
        config = BackendConfig(
            backend_type=resolved.backend,
            model_id=resolved.model_id,
            base_url=resolved.base_url,
            quantization=quantization or resolved.label,
        )
When `backend_type` is provided, `get_backend_for_model()` passes `model_str` through as `model_id` without stripping an explicit prefix (e.g. `ollama:...`, `gguf:...`). This makes forced-backend usage inconsistent with `resolve_model()` and can easily produce invalid `model_id`s for the backend. Consider always calling `resolve_model()` first to normalize `model_id`/`base_url`, and then overriding `backend_type` if the flag is set.
Suggested change:

    -    if backend_type:
    -        config = BackendConfig(
    -            backend_type=backend_type,  # type: ignore
    -            model_id=model_str,
    -            base_url=base_url,
    -            quantization=quantization,
    -        )
    -    else:
    -        resolved = resolve_model(model_str, base_url=base_url)
    -        config = BackendConfig(
    -            backend_type=resolved.backend,
    -            model_id=resolved.model_id,
    -            base_url=resolved.base_url,
    -            quantization=quantization or resolved.label,
    -        )
    +    # Always normalize the model string first to ensure consistent model_id/base_url
    +    resolved = resolve_model(model_str, base_url=base_url)
    +    config = BackendConfig(
    +        backend_type=backend_type or resolved.backend,  # type: ignore
    +        model_id=resolved.model_id,
    +        base_url=resolved.base_url,
    +        quantization=quantization or resolved.label,
    +    )
    dist_indices = cast(list[int], top_k_indices.tolist())

    distributions.append(dist_list)
    meta: dict[str, int | str] = {"is_aligned": 1}
`is_aligned` is set to 1 for MLX distributions even though the backend is only storing top-K entries (sorted by probability). This causes `_compare()` to treat the arrays as index-aligned and compute KL on mismatched token IDs, producing meaningless results. Consider removing `is_aligned` here (so union-by-id alignment is used) or only setting it when you actually store full-vocab distributions where indices are token IDs.
Suggested change:

    - meta: dict[str, int | str] = {"is_aligned": 1}
    + meta: dict[str, int | str] = {}
    # Get top-K indices and values
    top_k_indices = mx.argpartition(-logprob_dist, top_k - 1)[:top_k]
`top_k_logprobs` is used directly in `mx.argpartition(..., top_k - 1)` without bounds checking. If a prompt sets `top_k_logprobs` to 0, 1, or larger than the vocab size, this will raise at runtime. Clamp `top_k` to `[0, vocab_size]` (and handle `top_k < 1` early) before calling `argpartition`.
Suggested change:

    - # Get top-K indices and values
    - top_k_indices = mx.argpartition(-logprob_dist, top_k - 1)[:top_k]
    + # Clamp K to the vocabulary size to avoid out-of-bounds issues.
    + vocab_size = int(logprob_dist.shape[0])
    + if vocab_size <= 0:
    +     # Nothing to record for this step.
    +     continue
    + effective_top_k = int(top_k)
    + if effective_top_k < 1:
    +     # Should not happen due to the outer condition, but guard defensively.
    +     continue
    + if effective_top_k > vocab_size:
    +     effective_top_k = vocab_size
    + # Get top-K indices and values
    + top_k_indices = mx.argpartition(-logprob_dist, effective_top_k - 1)[:effective_top_k]
    # MLX repos (mlx-community org or -mlx suffix).
    if (
        spec_lower.startswith("mlx-community/")
        or spec_lower.endswith("-mlx")
        or "mlx" in spec_lower
The heuristic `or "mlx" in spec_lower` will match unrelated model names (e.g. a repo mentioning "mlxtend", the unrelated ML-extensions library, contains "mlx") and incorrectly route them to the `mlx-lm` backend. This makes backend detection unreliable. Consider tightening the check to explicit patterns (org `mlx-community/`, `-mlx` suffix, `/mlx-` segment, etc.) instead of a raw substring match.
Suggested change:

    - # MLX repos (mlx-community org or -mlx suffix).
    - if (
    -     spec_lower.startswith("mlx-community/")
    -     or spec_lower.endswith("-mlx")
    -     or "mlx" in spec_lower
    + # MLX repos (mlx-community org, -mlx suffix, or /mlx- segment).
    + if (
    +     spec_lower.startswith("mlx-community/")
    +     or spec_lower.endswith("-mlx")
    +     or "/mlx-" in spec_lower
    +     or spec_lower.startswith("mlx-")
| {"text": "Calculate 123456 * 987654. Show the step-by-step multiplication.", "category": "multi_digit_arithmetic", "max_tokens": 1024} | ||
| {"text": "What is the 10th root of 2? Provide the result to 10 decimal places.", "category": "precision_numerics", "max_tokens": 256} | ||
| {"text": "Determine if 9999999999999997 is a prime number. Explain your reasoning in detail.", "category": "large_number_reasoning", "max_tokens": 1024} | ||
| {"text": "A farmer has 17 sheep. All but 9 die. How many are left? Explain the logic step-by-step.", "category": "logical_puzzle", "max_tokens": 256} | ||
| {"text": "If a train leaves station A at 60 mph and another leaves station B at 90 mph, and they are 300 miles apart, when and where do they meet? Provide a detailed derivation.", "category": "long_chain_of_thought", "max_tokens": 512} | ||
| {"text": "Write a valid YAML configuration for a Kubernetes Deployment with a custom health check and resource limits. Ensure the indentation is perfect.", "category": "precise_syntax", "max_tokens": 512} | ||
| {"text": "Generate a complex nested JSON object with at least 5 levels of nesting, containing arrays of objects with mixed types. Ensure it's valid JSON.", "category": "precise_syntax", "max_tokens": 512} | ||
| {"text": "Translate this SQL query into a Python list comprehension: SELECT name FROM users WHERE age > 21 AND city = 'New York' ORDER BY name LIMIT 10", "category": "code_translation", "max_tokens": 256} | ||
| {"text": "Explain the difference between a shallow copy and a deep copy in Python with code examples showing the internal memory addresses using id().", "category": "long_chain_of_thought", "max_tokens": 1024} | ||
| {"text": "Solve for x: log2(x) + log2(x-2) = 3. Show every step of the algebraic manipulation.", "category": "algebraic_reasoning", "max_tokens": 512} | ||
| {"text": "What is the 50th Fibonacci number? Calculate it precisely without using scientific notation.", "category": "large_number_reasoning", "max_tokens": 512} | ||
| {"text": "Describe the steps of the SHA-256 hashing algorithm at a high level, but include the specific constants used in the initial hash values.", "category": "precision_numerics", "max_tokens": 1024} | ||
| {"text": "Write a Rust function that uses unsafe code to swap two integers using pointers. Explain why it is unsafe.", "category": "precise_syntax", "max_tokens": 512} | ||
| {"text": "Create a strictly formatted CSV list of the first 20 prime numbers, separated by semicolons, with each value enclosed in double quotes.", "category": "precise_syntax", "max_tokens": 256} | ||
| {"text": "If a clock shows 3:15, what is the exact angle between the hour and minute hands? Provide the derivation.", "category": "logical_puzzle", "max_tokens": 256} | ||
| {"text": "Compare the time complexity of QuickSort and MergeSort in the worst case, providing mathematical proofs for the Big O notation of each.", "category": "long_chain_of_thought", "max_tokens": 1024} | ||
| {"text": "Write a regular expression that matches valid email addresses according to RFC 5322, and explain each part of the regex.", "category": "precise_syntax", "max_tokens": 512} | ||
| {"text": "What is the exact value of e (Euler's number) to 30 decimal places?", "category": "precision_numerics", "max_tokens": 256} | ||
| {"text": "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain your reasoning.", "category": "logical_puzzle", "max_tokens": 256} | ||
| {"text": "Analyze the potential for race conditions in a multithreaded Python program that increments a global counter without locks. Provide a code example and explain how to fix it using threading.Lock.", "category": "long_chain_of_thought", "max_tokens": 1024} |
The new suite uses categories like `multi_digit_arithmetic`, `precision_numerics`, etc., but `answer_extract.extract_answer()` only treats a small fixed set of categories as numeric/boolean/code/json. As a result, flip-rate for this suite will mostly fall back to the low-confidence raw strategy (text similarity) rather than numeric/structured comparisons. Either map these new categories to existing ones or extend `answer_extract`'s category sets to include them.
| {"text": "Calculate 123456 * 987654. Show the step-by-step multiplication.", "category": "multi_digit_arithmetic", "max_tokens": 1024} | |
| {"text": "What is the 10th root of 2? Provide the result to 10 decimal places.", "category": "precision_numerics", "max_tokens": 256} | |
| {"text": "Determine if 9999999999999997 is a prime number. Explain your reasoning in detail.", "category": "large_number_reasoning", "max_tokens": 1024} | |
| {"text": "A farmer has 17 sheep. All but 9 die. How many are left? Explain the logic step-by-step.", "category": "logical_puzzle", "max_tokens": 256} | |
| {"text": "If a train leaves station A at 60 mph and another leaves station B at 90 mph, and they are 300 miles apart, when and where do they meet? Provide a detailed derivation.", "category": "long_chain_of_thought", "max_tokens": 512} | |
| {"text": "Write a valid YAML configuration for a Kubernetes Deployment with a custom health check and resource limits. Ensure the indentation is perfect.", "category": "precise_syntax", "max_tokens": 512} | |
| {"text": "Generate a complex nested JSON object with at least 5 levels of nesting, containing arrays of objects with mixed types. Ensure it's valid JSON.", "category": "precise_syntax", "max_tokens": 512} | |
| {"text": "Translate this SQL query into a Python list comprehension: SELECT name FROM users WHERE age > 21 AND city = 'New York' ORDER BY name LIMIT 10", "category": "code_translation", "max_tokens": 256} | |
| {"text": "Explain the difference between a shallow copy and a deep copy in Python with code examples showing the internal memory addresses using id().", "category": "long_chain_of_thought", "max_tokens": 1024} | |
| {"text": "Solve for x: log2(x) + log2(x-2) = 3. Show every step of the algebraic manipulation.", "category": "algebraic_reasoning", "max_tokens": 512} | |
| {"text": "What is the 50th Fibonacci number? Calculate it precisely without using scientific notation.", "category": "large_number_reasoning", "max_tokens": 512} | |
| {"text": "Describe the steps of the SHA-256 hashing algorithm at a high level, but include the specific constants used in the initial hash values.", "category": "precision_numerics", "max_tokens": 1024} | |
| {"text": "Write a Rust function that uses unsafe code to swap two integers using pointers. Explain why it is unsafe.", "category": "precise_syntax", "max_tokens": 512} | |
| {"text": "Create a strictly formatted CSV list of the first 20 prime numbers, separated by semicolons, with each value enclosed in double quotes.", "category": "precise_syntax", "max_tokens": 256} | |
| {"text": "If a clock shows 3:15, what is the exact angle between the hour and minute hands? Provide the derivation.", "category": "logical_puzzle", "max_tokens": 256} | |
| {"text": "Compare the time complexity of QuickSort and MergeSort in the worst case, providing mathematical proofs for the Big O notation of each.", "category": "long_chain_of_thought", "max_tokens": 1024} | |
| {"text": "Write a regular expression that matches valid email addresses according to RFC 5322, and explain each part of the regex.", "category": "precise_syntax", "max_tokens": 512} | |
| {"text": "What is the exact value of e (Euler's number) to 30 decimal places?", "category": "precision_numerics", "max_tokens": 256} | |
| {"text": "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain your reasoning.", "category": "logical_puzzle", "max_tokens": 256} | |
| {"text": "Analyze the potential for race conditions in a multithreaded Python program that increments a global counter without locks. Provide a code example and explain how to fix it using threading.Lock.", "category": "long_chain_of_thought", "max_tokens": 1024} | |
| {"text": "Calculate 123456 * 987654. Show the step-by-step multiplication.", "category": "numeric", "max_tokens": 1024} | |
| {"text": "What is the 10th root of 2? Provide the result to 10 decimal places.", "category": "numeric", "max_tokens": 256} | |
| {"text": "Determine if 9999999999999997 is a prime number. Explain your reasoning in detail.", "category": "numeric", "max_tokens": 1024} | |
| {"text": "A farmer has 17 sheep. All but 9 die. How many are left? Explain the logic step-by-step.", "category": "numeric", "max_tokens": 256} | |
| {"text": "If a train leaves station A at 60 mph and another leaves station B at 90 mph, and they are 300 miles apart, when and where do they meet? Provide a detailed derivation.", "category": "numeric", "max_tokens": 512} | |
| {"text": "Write a valid YAML configuration for a Kubernetes Deployment with a custom health check and resource limits. Ensure the indentation is perfect.", "category": "code", "max_tokens": 512} | |
| {"text": "Generate a complex nested JSON object with at least 5 levels of nesting, containing arrays of objects with mixed types. Ensure it's valid JSON.", "category": "json", "max_tokens": 512} | |
| {"text": "Translate this SQL query into a Python list comprehension: SELECT name FROM users WHERE age > 21 AND city = 'New York' ORDER BY name LIMIT 10", "category": "code", "max_tokens": 256} | |
| {"text": "Explain the difference between a shallow copy and a deep copy in Python with code examples showing the internal memory addresses using id().", "category": "code", "max_tokens": 1024} | |
| {"text": "Solve for x: log2(x) + log2(x-2) = 3. Show every step of the algebraic manipulation.", "category": "numeric", "max_tokens": 512} | |
| {"text": "What is the 50th Fibonacci number? Calculate it precisely without using scientific notation.", "category": "numeric", "max_tokens": 512} | |
| {"text": "Describe the steps of the SHA-256 hashing algorithm at a high level, but include the specific constants used in the initial hash values.", "category": "numeric", "max_tokens": 1024} | |
| {"text": "Write a Rust function that uses unsafe code to swap two integers using pointers. Explain why it is unsafe.", "category": "code", "max_tokens": 512} | |
| {"text": "Create a strictly formatted CSV list of the first 20 prime numbers, separated by semicolons, with each value enclosed in double quotes.", "category": "code", "max_tokens": 256} | |
| {"text": "If a clock shows 3:15, what is the exact angle between the hour and minute hands? Provide the derivation.", "category": "numeric", "max_tokens": 256} | |
| {"text": "Compare the time complexity of QuickSort and MergeSort in the worst case, providing mathematical proofs for the Big O notation of each.", "category": "raw", "max_tokens": 1024} | |
| {"text": "Write a regular expression that matches valid email addresses according to RFC 5322, and explain each part of the regex.", "category": "code", "max_tokens": 512} | |
| {"text": "What is the exact value of e (Euler's number) to 30 decimal places?", "category": "numeric", "max_tokens": 256} | |
| {"text": "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain your reasoning.", "category": "numeric", "max_tokens": 256} | |
| {"text": "Analyze the potential for race conditions in a multithreaded Python program that increments a global counter without locks. Provide a code example and explain how to fix it using threading.Lock.", "category": "code", "max_tokens": 1024} |
    safe_a = resolved_a.label.replace("/", "_")
    safe_b = resolved_b.label.replace("/", "_")
    out_path = output / f"compare_{safe_a}_vs_{safe_b}.json"
    compare_result.save(out_path)
`safe_a` / `safe_b` only replace `/`, but labels can still contain characters that are invalid on some filesystems (e.g. `:` on Windows) or create messy filenames when users pass `--label-a`/`--label-b`. Consider reusing the stricter sanitization used in `TestRunner.compare()` (or a shared helper) when constructing `out_path`.
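A shared helper along the lines this comment suggests could look like the following sketch; the name `sanitize_label` and the allowed character set are assumptions, not the PR's actual helper:

```python
import re


def sanitize_label(label: str) -> str:
    """Make a model label safe to embed in a filename (illustrative sketch)."""
    # Keep alphanumerics, dot, dash, underscore; collapse runs of anything
    # else (slashes, colons, spaces, ...) into a single underscore.
    safe = re.sub(r"[^A-Za-z0-9._-]+", "_", label)
    # Guard against labels that sanitize away entirely.
    return safe.strip("_") or "model"
```

Using one helper for both the CLI `out_path` and the runner's checkpoint names keeps the two file layouts consistent.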
    @click.option("--backend", default=None, help="Backend type (auto-detected if omitted).")
    @click.option(
Now that `--backend` defaults to `None` (auto-detect), downstream code that uses the backend option value as part of output filenames/logging will end up with confusing `..._None.json` names when the flag is omitted. Consider either omitting the backend suffix when `backend` is `None`, or deriving a stable identifier from the resolved backend(s).
    # prompts = _make_prompts(1)
    # resp = {prompts[0].id: "ok"}
`test_checkpoint_written` is currently a stub with all assertions commented out, so it doesn't validate the checkpoint-writing behavior it claims to test. Either complete the test (e.g., run `compare()` with 1 prompt and assert a `compare_*.json` file appears in `cache_dir`) or remove the placeholder to avoid dead test code.
Suggested change:

    - # prompts = _make_prompts(1)
    - # resp = {prompts[0].id: "ok"}
    + prompts = _make_prompts(1)
    + resp = {prompts[0].id: "ok"}
    + backend_a = self._make_mock_backend("a", resp)
    + backend_b = self._make_mock_backend("b", resp)
    + # Run a compare so that a checkpoint should be written.
    + asyncio.run(
    +     test_runner.compare(
    +         backend_a=backend_a,
    +         backend_b=backend_b,
    +         prompts=prompts,
    +     )
    + )
    + cache_dir = Path(test_runner.cache_dir)
    + assert cache_dir.exists() and cache_dir.is_dir()
    + checkpoint_files = list(cache_dir.glob("compare_*.json"))
    + assert checkpoint_files, "Expected at least one compare_*.json checkpoint file to be written"
This pull request updates the prompt suite infrastructure and documentation for the project, with a focus on prompt suite management and test coverage. The most significant changes are the addition of a new "quant-sensitive" prompt suite, updates to documentation to reflect this, and the removal of all prompts from several existing suite files—likely as part of a migration, refactor, or deprecation. There is also a minor dependency update to support new model integrations.
Prompt suite and documentation updates:
- Added the new `quant-sensitive.jsonl` prompt suite (20 prompts) targeting multi-digit arithmetic, long chain-of-thought, and precise syntax; updated the documentation to include this suite as a bundled option. [1] [2]
- Removed all prompts from `adversarial-numerics.jsonl`, `code.jsonl`, and `determinism.jsonl`, effectively emptying these files. This may indicate a migration, deprecation, or preparation for new content. [1] [2] [3]

Dependency management:
- Added `mlx-lm` as a dependency to the `mypy` pre-commit hook environment in `.pre-commit-config.yaml`, likely to enable type checking or integration for the new model.