diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3791b5a..c33bda3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,7 @@ repos: - click>=8.1.0 - httpx>=0.27.0 - types-jinja2 + - mlx-lm<0.31.0 args: ["--strict", "--ignore-missing-imports"] pass_filenames: false entry: mypy src/infer_check/ diff --git a/README.md b/README.md index 8396767..fc43b96 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ infer-check sweep \ --output ./results/sweep/ ``` -`--prompts` accepts either a bundled suite name (`reasoning`, `code`, `adversarial-numerics`, `determinism`, `long-context`) or a path to any `.jsonl` file. +`--prompts` accepts either a bundled suite name (`reasoning`, `code`, `adversarial-numerics`, `determinism`, `long-context`, `quant-sensitive`) or a path to any `.jsonl` file. The baseline is automatically run twice as a self-check — if it's not 50/50 identical, your comparison data is unreliable. @@ -158,6 +158,7 @@ Curated prompts targeting known quantization failure modes: | `code.jsonl` | 49 | Python, JSON, SQL generation | | `adversarial-numerics.jsonl` | 30 | IEEE 754 edge cases, overflow, precision | | `long-context.jsonl` | 10 | Tables and transcripts with recall questions | +| `quant-sensitive.jsonl` | 20 | Multi-digit arithmetic, long CoT, precise syntax | | `determinism.jsonl` | 50 | High-entropy continuations for determinism testing | All suites ship with the package — no need to clone the repo. Custom suites are JSONL files with one object per line: diff --git a/prompt-suites/adversarial-numerics.jsonl b/prompt-suites/adversarial-numerics.jsonl deleted file mode 100644 index 6baf154..0000000 --- a/prompt-suites/adversarial-numerics.jsonl +++ /dev/null @@ -1,30 +0,0 @@ -{"text": "What is 0.1 + 0.2? Give a precise decimal answer.", "category": "floating_point", "max_tokens": 256} -{"text": "Compute 9999999999999999 + 1.", "category": "large_numbers", "max_tokens": 256} -{"text": "What is 10^308?", "category": "large_numbers", "max_tokens": 256} -{"text": "What is 1/3 expressed as a decimal to 20 places?", "category": "precision", "max_tokens": 256} -{"text": "Is 2^31 - 1 equal to 2147483647?", "category": "large_numbers", "max_tokens": 256} -{"text": "What is the decimal representation of 1/7?", "category": "precision", "max_tokens": 256} -{"text": "Compute: (-1)^(1/3). Is it -1 or a complex number?", "category": "edge_case", "max_tokens": 256} -{"text": "What is 0^0?", "category": "edge_case", "max_tokens": 256} -{"text": "What is the value of 1e-300 * 1e-300?", "category": "underflow", "max_tokens": 256} -{"text": "Compute 170! (170 factorial). How many digits does it have?", "category": "large_numbers", "max_tokens": 256} -{"text": "Write the number 1000000000000 in words.", "category": "formatting", "max_tokens": 256} -{"text": "What is -0.0 equal to? Is -0.0 == 0.0 in IEEE 754?", "category": "edge_case", "max_tokens": 256} -{"text": "Convert $1,234,567.89 to Japanese yen at a rate of 149.32 yen per dollar.", "category": "precision", "max_tokens": 256} -{"text": "What is infinity minus infinity?", "category": "edge_case", "max_tokens": 256} -{"text": "How many seconds are in a leap year? Show your calculation.", "category": "precision", "max_tokens": 256} -{"text": "What is floor(-2.5)? What is round(-2.5)?", "category": "edge_case", "max_tokens": 256} -{"text": "Express the speed of light (299792458 m/s) in miles per hour.", "category": "precision", "max_tokens": 256} -{"text": "What is 2^53 + 1? Can this number be exactly represented as a 64-bit float?", "category": "floating_point", "max_tokens": 256} -{"text": "Compute the sum: 1 + 1/2 + 1/4 + 1/8 + ... (infinite geometric series)", "category": "precision", "max_tokens": 256} -{"text": "What is NaN == NaN in IEEE 754 floating point?", "category": "edge_case", "max_tokens": 256} -{"text": "How many grains of sand are on Earth? Express in scientific notation.", "category": "large_numbers", "max_tokens": 256} -{"text": "What is 7/13 * 13/7? Is the result exactly 1?", "category": "floating_point", "max_tokens": 256} -{"text": "Convert the binary number 10110011.101 to decimal.", "category": "formatting", "max_tokens": 256} -{"text": "What is 999...9 (fifty 9s) + 1?", "category": "large_numbers", "max_tokens": 256} -{"text": "Compute 1/0. What happens? Now compute 1/0.0 in IEEE 754.", "category": "edge_case", "max_tokens": 256} -{"text": "Express π to 30 decimal places.", "category": "precision", "max_tokens": 256} -{"text": "What is MAX_INT in a 64-bit signed integer? What happens if you add 1 to it?", "category": "edge_case", "max_tokens": 256} -{"text": "How many microseconds are in 3 years, 7 months, and 12 days?", "category": "precision", "max_tokens": 512} -{"text": "What is 1/998001? This produces an interesting decimal pattern.", "category": "precision", "max_tokens": 512} -{"text": "Is 0.9999999999 (ten 9s) equal to 1? What about 0.999... (infinitely repeating)?", "category": "floating_point", "max_tokens": 256} diff --git a/prompt-suites/code.jsonl b/prompt-suites/code.jsonl deleted file mode 100644 index 0f6d655..0000000 --- a/prompt-suites/code.jsonl +++ /dev/null @@ -1,49 +0,0 @@ -{"text": "Write a Python function that checks if a string is a valid IPv4 address.", "category": "python", "max_tokens": 512} -{"text": "Generate a valid JSON object representing a user profile with name, email, age, and a list of 3 hobbies.", "category": "json", "max_tokens": 256} -{"text": "Write a Python function to find the longest common subsequence of two strings.", "category": "python", "max_tokens": 512} -{"text": "Here is a buggy function:\ndef fib(n):\n if n <= 1:\n return 1\n return fib(n-1) + fib(n-2)\nWhat is wrong with it? Fix it so fib(0)=0, fib(1)=1.", "category": "debugging", "max_tokens": 512} -{"text": "Write a Python class called BankAccount with deposit, withdraw, and balance methods. Withdrawals should fail if insufficient funds.", "category": "python", "max_tokens": 512} -{"text": "Generate a JSON array of 5 objects, each with fields: id (integer), name (string), score (float between 0 and 1), tags (array of strings).", "category": "json", "max_tokens": 512} -{"text": "Write a Python function that merges two sorted lists into one sorted list without using the built-in sort.", "category": "python", "max_tokens": 512} -{"text": "Complete this Python code:\nimport re\ndef extract_emails(text):\n \"\"\"Return all email addresses found in text.\"\"\"", "category": "completion", "max_tokens": 256} -{"text": "Write a SQL query that finds the second highest salary from an employees table.", "category": "sql", "max_tokens": 256} -{"text": "Here is buggy code:\ndef binary_search(arr, target):\n low, high = 0, len(arr)\n while low < high:\n mid = (low + high) // 2\n if arr[mid] == target:\n return mid\n elif arr[mid] < target:\n low = mid\n else:\n high = mid\n return -1\nThis has an infinite loop bug. Find and fix it.", "category": "debugging", "max_tokens": 512} -{"text": "Write a Python generator function that yields prime numbers indefinitely.", "category": "python", "max_tokens": 512} -{"text": "Generate a JSON Schema that validates an object with required fields: name (string, min 1 char), age (integer, min 0, max 150), email (string, format email).", "category": "json", "max_tokens": 512} -{"text": "Write a Python decorator that caches function results (memoization) using a dictionary.", "category": "python", "max_tokens": 512} -{"text": "Complete this code:\nclass Stack:\n def __init__(self):\n self._items = []\n \n def push(self, item):", "category": "completion", "max_tokens": 256} -{"text": "Write a Python function that converts a Roman numeral string to an integer.", "category": "python", "max_tokens": 512} -{"text": "Here is buggy code:\ndef flatten(lst):\n result = []\n for item in lst:\n if isinstance(item, list):\n result.extend(item)\n else:\n result.append(item)\n return result\nThis doesn't handle deeply nested lists. Fix it to work recursively.", "category": "debugging", "max_tokens": 512} -{"text": "Write a Python async function that fetches 3 URLs concurrently using asyncio and aiohttp.", "category": "python", "max_tokens": 512} -{"text": "Generate valid YAML for a Kubernetes deployment with 3 replicas of an nginx container on port 80.", "category": "yaml", "max_tokens": 512} -{"text": "Write a Python function that implements the Levenshtein edit distance between two strings.", "category": "python", "max_tokens": 512} -{"text": "Complete this code:\ndef parse_csv_line(line: str) -> list[str]:\n \"\"\"Parse a CSV line handling quoted fields with commas inside.\"\"\"", "category": "completion", "max_tokens": 512} -{"text": "Write a Python context manager class that measures and prints execution time of a code block.", "category": "python", "max_tokens": 256} -{"text": "Generate a JSON Web Token (JWT) payload with fields: sub, iat (unix timestamp for now), exp (1 hour from now), role, permissions array.", "category": "json", "max_tokens": 256} -{"text": "Write a Python function that takes a nested dictionary and flattens it with dot-notation keys. Example: {'a': {'b': 1}} -> {'a.b': 1}", "category": "python", "max_tokens": 512} -{"text": "Here is code with a subtle bug:\ndef unique_chars(s):\n return len(s) == len(set(s))\nDoes this work for all Unicode strings? What about combining characters?", "category": "debugging", "max_tokens": 512} -{"text": "Write a Python function that validates whether a string of parentheses, brackets, and braces is balanced.", "category": "python", "max_tokens": 256} -{"text": "Generate a JSON object representing an API error response following RFC 7807 (Problem Details).", "category": "json", "max_tokens": 256} -{"text": "Write a Python function that implements run-length encoding. 'AAABBBCCCC' -> '3A3B4C'", "category": "python", "max_tokens": 256} -{"text": "Write a Python type-annotated function that takes a list of dicts and groups them by a specified key.", "category": "python", "max_tokens": 512} -{"text": "Complete this code to implement a simple LRU cache:\nclass LRUCache:\n def __init__(self, capacity: int):", "category": "completion", "max_tokens": 512} -{"text": "Write a regular expression that matches valid email addresses. Explain each part.", "category": "python", "max_tokens": 512} -{"text": "Here is buggy Python:\ndef quicksort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[0]\n left = [x for x in arr if x < pivot]\n right = [x for x in arr if x > pivot]\n return quicksort(left) + [pivot] + quicksort(right)\nWhat happens with duplicate elements? Fix it.", "category": "debugging", "max_tokens": 512} -{"text": "Write a Python function that converts a Python dictionary to a valid GraphQL query string.", "category": "python", "max_tokens": 512} -{"text": "Generate a JSON array where each element is a date string in ISO 8601 format for every Monday in March 2026.", "category": "json", "max_tokens": 256} -{"text": "Write a Python function that reads a file and returns the 10 most frequent words with their counts.", "category": "python", "max_tokens": 512} -{"text": "Complete this async generator:\nasync def read_chunks(stream, chunk_size=1024):\n \"\"\"Yield chunks from an async byte stream.\"\"\"", "category": "completion", "max_tokens": 256} -{"text": "Write a Python dataclass for a 2D Point with distance_to, midpoint, and __add__ methods.", "category": "python", "max_tokens": 512} -{"text": "Write a one-liner Python list comprehension that generates all Pythagorean triples where a,b,c < 100.", "category": "python", "max_tokens": 256} -{"text": "Generate a minimal valid HTML5 document with a title, a heading, a paragraph, and a link.", "category": "html", "max_tokens": 256} -{"text": "Write a Python function that converts a flat list of parent-child pairs into a tree structure.", "category": "python", "max_tokens": 512} -{"text": "Here is buggy code:\ndef safe_divide(a, b):\n try:\n return a / b\n except:\n return 0\nWhat are the problems with this error handling? Rewrite it properly.", "category": "debugging", "max_tokens": 512} -{"text": "Write a Python function that implements the Sieve of Eratosthenes up to n.", "category": "python", "max_tokens": 256} -{"text": "Generate a JSON object with intentionally tricky values: empty string, null, false, 0, empty array, nested nulls.", "category": "json", "max_tokens": 256} -{"text": "Write a Dockerfile for a Python 3.11 FastAPI app that listens on port 8000.", "category": "devops", "max_tokens": 256} -{"text": "Write a Python function that takes markdown text and extracts all URLs from it.", "category": "python", "max_tokens": 256} -{"text": "Complete this:\ndef retry(max_attempts=3, delay=1.0):\n \"\"\"Decorator that retries a function on exception.\"\"\"", "category": "completion", "max_tokens": 512} -{"text": "Write a Python function that serializes a datetime object to ISO 8601 string and deserializes it back, handling timezone-aware and naive datetimes.", "category": "python", "max_tokens": 512} -{"text": "What does this code print and why?\nx = [1, 2, 3]\ny = x\ny.append(4)\nprint(x)", "category": "debugging", "max_tokens": 256} -{"text": "Write a Python function that implements consistent hashing for distributing keys across N nodes.", "category": "python", "max_tokens": 512} -{"text": "Generate a JSON-LD object representing a Person with name, job title, and employer according to schema.org.", "category": "json", "max_tokens": 256} diff --git a/prompt-suites/determinism.jsonl b/prompt-suites/determinism.jsonl deleted file mode 100644 index 41e2e89..0000000 --- a/prompt-suites/determinism.jsonl +++ /dev/null @@ -1,50 +0,0 @@ -{"text": "Tell me a short story about a robot who learns to paint.", "category": "creative", "max_tokens": 256} -{"text": "Explain quantum entanglement in simple terms.", "category": "factual", "max_tokens": 256} -{"text": "List 5 unusual uses for a paperclip.", "category": "list", "max_tokens": 256} -{"text": "What would happen if gravity suddenly became 10% weaker?", "category": "creative", "max_tokens": 256} -{"text": "Describe the taste of water to someone who has never tasted anything.", "category": "creative", "max_tokens": 256} -{"text": "Name 7 animals that start with the letter K.", "category": "list", "max_tokens": 256} -{"text": "Write a haiku about debugging code.", "category": "creative", "max_tokens": 128} -{"text": "Explain why the sky is blue to a 5-year-old.", "category": "explanation", "max_tokens": 256} -{"text": "Give me a recipe for a meal using only eggs, rice, and soy sauce.", "category": "list", "max_tokens": 256} -{"text": "What are 3 arguments for and 3 against remote work?", "category": "list", "max_tokens": 512} -{"text": "Invent a word and define it.", "category": "creative", "max_tokens": 256} -{"text": "Explain the difference between a virus and a bacterium.", "category": "factual", "max_tokens": 256} -{"text": "Describe what a sunset looks like from the surface of Mars.", "category": "creative", "max_tokens": 256} -{"text": "List the planets of the solar system in order from the Sun.", "category": "factual", "max_tokens": 128} -{"text": "Write a limerick about a programmer.", "category": "creative", "max_tokens": 128} -{"text": "What came first, the chicken or the egg? Give a biological answer.", "category": "explanation", "max_tokens": 256} -{"text": "Name 10 words that are their own antonyms (contronyms).", "category": "list", "max_tokens": 256} -{"text": "Explain how a refrigerator works.", "category": "explanation", "max_tokens": 256} -{"text": "Write a dialogue between a cat and a dog meeting for the first time.", "category": "creative", "max_tokens": 256} -{"text": "What are the pros and cons of nuclear energy?", "category": "list", "max_tokens": 512} -{"text": "Describe the color blue to a blind person.", "category": "creative", "max_tokens": 256} -{"text": "List 5 historical events that happened on March 15.", "category": "factual", "max_tokens": 256} -{"text": "Write a tweet-length summary of World War II.", "category": "creative", "max_tokens": 128} -{"text": "Explain what a blockchain is without using technical jargon.", "category": "explanation", "max_tokens": 256} -{"text": "Generate a random D&D character: race, class, background, and one personality trait.", "category": "creative", "max_tokens": 256} -{"text": "What would you name a coffee shop on the Moon?", "category": "creative", "max_tokens": 128} -{"text": "Explain the trolley problem and give your analysis.", "category": "explanation", "max_tokens": 512} -{"text": "List 8 things you can do with a brick besides building.", "category": "list", "max_tokens": 256} -{"text": "Write an acrostic poem where the first letters spell PYTHON.", "category": "creative", "max_tokens": 256} -{"text": "Explain how GPS works in 3 sentences.", "category": "explanation", "max_tokens": 128} -{"text": "Describe the internet to someone from the year 1900.", "category": "creative", "max_tokens": 256} -{"text": "What are 5 cognitive biases that affect decision-making?", "category": "list", "max_tokens": 256} -{"text": "Write a fortune cookie message.", "category": "creative", "max_tokens": 64} -{"text": "Explain the Monty Hall problem and the correct strategy.", "category": "explanation", "max_tokens": 512} -{"text": "If you could add one amendment to the US Constitution, what would it be and why?", "category": "creative", "max_tokens": 256} -{"text": "List all the US states that border the Pacific Ocean.", "category": "factual", "max_tokens": 128} -{"text": "Write a one-paragraph pitch for a startup that sells AI-generated perfumes.", "category": "creative", "max_tokens": 256} -{"text": "Explain the difference between machine learning, deep learning, and AI.", "category": "explanation", "max_tokens": 256} -{"text": "Name 6 foods that are technically berries (that most people don't realize).", "category": "factual", "max_tokens": 256} -{"text": "What are 4 things that would be different if humans had evolved with tails?", "category": "creative", "max_tokens": 256} -{"text": "Write instructions for making a peanut butter and jelly sandwich so precisely that a robot could follow them.", "category": "explanation", "max_tokens": 512} -{"text": "Generate 5 band names that don't exist yet.", "category": "creative", "max_tokens": 128} -{"text": "Explain what causes deja vu.", "category": "explanation", "max_tokens": 256} -{"text": "List the 5 largest deserts in the world by area.", "category": "factual", "max_tokens": 128} -{"text": "Write a review of a restaurant that only serves food from the year 2200.", "category": "creative", "max_tokens": 256} -{"text": "Explain why we dream.", "category": "explanation", "max_tokens": 256} -{"text": "What are 3 things that are true about octopuses that sound made up?", "category": "factual", "max_tokens": 256} -{"text": "Generate a plot summary for a movie where time runs backwards.", "category": "creative", "max_tokens": 256} -{"text": "Explain the ship of Theseus paradox.", "category": "explanation", "max_tokens": 256} -{"text": "List 5 inventions that were discovered by accident.", "category": "factual", "max_tokens": 256} diff --git a/prompt-suites/long-context.jsonl b/prompt-suites/long-context.jsonl deleted file mode 100644 index aa94c10..0000000 --- a/prompt-suites/long-context.jsonl +++ /dev/null @@ -1,10 +0,0 @@ -{"text": "Here is a list of fictional employees and their departments:\n\nAlice Chen - Engineering (joined 2019, salary $125,000, office 4A)\nBob Martinez - Marketing (joined 2020, salary $95,000, office 2B)\nCarol Williams - Engineering (joined 2018, salary $142,000, office 4C)\nDavid Kim - Sales (joined 2021, salary $88,000, office 1A)\nEva Johnson - Engineering (joined 2017, salary $155,000, office 4D)\nFrank Brown - Marketing (joined 2022, salary $82,000, office 2A)\nGrace Lee - Sales (joined 2019, salary $97,000, office 1B)\nHenry Davis - Engineering (joined 2020, salary $131,000, office 4B)\nIris Patel - HR (joined 2018, salary $105,000, office 3A)\nJack Wilson - Sales (joined 2023, salary $78,000, office 1C)\nKaren Thompson - Engineering (joined 2016, salary $168,000, office 4E)\nLiam Garcia - Marketing (joined 2021, salary $89,000, office 2C)\nMona Robinson - HR (joined 2020, salary $98,000, office 3B)\nNate Foster - Engineering (joined 2019, salary $138,000, office 4F)\nOlivia Scott - Sales (joined 2020, salary $92,000, office 1D)\nPeter Adams - Marketing (joined 2019, salary $101,000, office 2D)\nQuinn Harris - Engineering (joined 2022, salary $118,000, office 4G)\nRachel Clark - HR (joined 2021, salary $94,000, office 3C)\nSam Turner - Sales (joined 2018, salary $112,000, office 1E)\nTina Wright - Engineering (joined 2020, salary $129,000, office 4H)\n\nQuestions:\n1. What is the average salary in the Engineering department?\n2. Who joined most recently?\n3. Which department has the highest average salary?\n4. List all employees in office building 4 (offices starting with 4).\n5. What is the total salary expenditure for Sales?", "category": "recall", "max_tokens": 512} -{"text": "Read the following passage carefully. There will be questions at the end.\n\nThe Kepler space telescope, launched by NASA on March 7, 2009, was designed to survey a portion of the Milky Way galaxy to discover Earth-size exoplanets in or near habitable zones and estimate how many of the billions of stars in the Milky Way have such planets. Named after the German astronomer Johannes Kepler, it observed 530,506 stars and detected 2,662 exoplanets over its operational lifetime. The spacecraft operated from an Earth-trailing heliocentric orbit and used a photometer that continuously monitored the brightness of over 150,000 main sequence stars in a fixed field of view.\n\nThe telescope had a 0.95-meter aperture Schmidt camera design with a 105 square degree field of view, much larger than any previous space telescope. Its CCD array contained 42 CCDs with 2200x1024 pixels each, totaling over 95 million pixels. The mission cost approximately $600 million over its lifetime.\n\nKepler detected planets using the transit method, observing tiny periodic dips in stellar brightness when a planet crosses in front of its host star. A typical Earth-like planet transiting a Sun-like star would cause a brightness reduction of about 84 parts per million. The telescope needed to observe at least three transits to confirm a planetary candidate.\n\nThe original mission ended in May 2013 when a second of four reaction wheels failed, preventing precise pointing. NASA repurposed the spacecraft as K2, which observed different patches of sky along the ecliptic plane. K2 discovered an additional 529 confirmed exoplanets before the spacecraft ran out of fuel and was retired on October 30, 2018.\n\nQuestions:\n1. How many stars did Kepler observe in total?\n2. What was the total pixel count of its CCD array?\n3. How many reaction wheels needed to fail before the original mission ended?\n4. How many confirmed exoplanets did K2 (the repurposed mission) discover?\n5. What date was the spacecraft retired?", "category": "recall", "max_tokens": 512} -{"text": "I will give you a sequence of 50 words. Memorize them, then answer questions.\n\nThe words are: telescope, marmalade, algorithm, cathedral, phosphorus, labyrinth, chandelier, porcupine, xylophone, archipelago, tambourine, holographic, cinnamon, thunderstorm, palindrome, eucalyptus, kaleidoscope, watermelon, observatory, chameleon, trampoline, mysterious, dandelion, helicopter, magnificent, strawberry, constellation, turquoise, microphone, caterpillar, avalanche, harmonica, crocodile, tangerine, periscope, butterfly, serendipity, pineapple, rhinoceros, Antarctica, guillotine, pomegranate, cantaloupe, trampoline, wilderness, saxophone, illuminate, centennial, dragonfly, escalator.\n\nQuestions:\n1. What was the 7th word?\n2. What was the 23rd word?\n3. Did the word 'trampoline' appear more than once? If so, at which positions?\n4. What was the last word?\n5. List all words that start with the letter 'c'.", "category": "needle_in_haystack", "max_tokens": 512} -{"text": "Below is a conversation log between a customer and a support agent. Find the key information.\n\n[10:01 AM] Customer: Hi, I'm having trouble with my order #ORD-2024-87432.\n[10:02 AM] Agent: Hello! I'd be happy to help. Let me look up that order.\n[10:03 AM] Agent: I can see your order was placed on February 15th for a Blue Widget XL (SKU: BW-XL-003), quantity 2, at $49.99 each.\n[10:04 AM] Customer: Yes, but I received Red Widgets instead of Blue ones.\n[10:05 AM] Agent: I'm sorry about that. Let me check the warehouse records.\n[10:06 AM] Agent: It appears there was a picking error. The warehouse pulled items from bin B-17 instead of bin B-71.\n[10:07 AM] Customer: Can I get a replacement shipped?\n[10:08 AM] Agent: Absolutely. I'm creating a replacement order now. Your new order number is ORD-2024-87598. It will ship via Priority Express with tracking number 1Z999AA10123456784.\n[10:09 AM] Customer: Great. Do I need to return the wrong items?\n[10:10 AM] Agent: No, please keep the Red Widgets as a courtesy. Your replacement Blue Widgets will arrive within 2-3 business days.\n[10:11 AM] Agent: I've also applied a 15% discount code SORRY15 to your account for the inconvenience.\n[10:12 AM] Customer: Thank you so much! One more thing - my shipping address is 742 Evergreen Terrace, Springfield, IL 62704.\n[10:13 AM] Agent: Confirmed, shipping to that address. Is there anything else I can help with?\n[10:14 AM] Customer: No, that's all. Thanks!\n\nQuestions:\n1. What was the original order number?\n2. What SKU was ordered?\n3. Which bin should the warehouse have pulled from?\n4. What is the tracking number for the replacement?\n5. What discount code was provided?\n6. What is the customer's zip code?", "category": "recall", "max_tokens": 512} -{"text": "Below is a table of chemical elements with their properties. Study it carefully.\n\n| Element | Symbol | Atomic Number | Atomic Mass | Melting Point (°C) | Boiling Point (°C) | Density (g/cm³) | Year Discovered |\n|---------|--------|---------------|-------------|-------------------|--------------------|-----------------|-----------------|\n| Hydrogen | H | 1 | 1.008 | -259.16 | -252.87 | 0.00008988 | 1766 |\n| Helium | He | 2 | 4.003 | -272.20 | -268.93 | 0.0001785 | 1868 |\n| Lithium | Li | 3 | 6.941 | 180.54 | 1342 | 0.534 | 1817 |\n| Carbon | C | 6 | 12.011 | 3550 | 4027 | 2.267 | ancient |\n| Nitrogen | N | 7 | 14.007 | -210.00 | -195.79 | 0.0012506 | 1772 |\n| Oxygen | O | 8 | 15.999 | -218.79 | -182.96 | 0.001429 | 1774 |\n| Iron | Fe | 26 | 55.845 | 1538 | 2862 | 7.874 | ancient |\n| Copper | Cu | 29 | 63.546 | 1084.62 | 2562 | 8.96 | ancient |\n| Silver | Ag | 47 | 107.868 | 961.78 | 2162 | 10.49 | ancient |\n| Gold | Au | 79 | 196.967 | 1064.18 | 2856 | 19.3 | ancient |\n| Mercury | Hg | 80 | 200.592 | -38.83 | 356.73 | 13.534 | ancient |\n| Uranium | U | 92 | 238.029 | 1132.2 | 4131 | 19.1 | 1789 |\n\nQuestions:\n1. Which element has the lowest melting point?\n2. What is the density of gold?\n3. Which elements were discovered in the 18th century (1700s)?\n4. Which element has the highest boiling point?\n5. What is the atomic mass of silver?", "category": "recall", "max_tokens": 512} -{"text": "Summarize the following technical specification in exactly 3 bullet points:\n\nThe XR-7000 Industrial Controller supports up to 256 discrete I/O channels, 64 analog input channels (16-bit resolution, ±10V range, 100kHz aggregate sampling rate), and 32 analog output channels (14-bit resolution, ±10V range). Communication interfaces include dual Gigabit Ethernet ports with EtherCAT support (cycle time down to 125μs), one RS-485 serial port (up to 12Mbps), and an optional CAN bus module (CAN 2.0B, up to 1Mbps). The controller runs a real-time Linux kernel (PREEMPT_RT patch) with guaranteed worst-case interrupt latency of 50μs. Internal storage consists of 32GB eMMC for the OS and application code, plus a removable SD card slot (up to 512GB) for data logging. The unit operates from 24VDC (±20%) input power, consuming 45W typical and 78W maximum. Operating temperature range is -20°C to +60°C with no derating. The chassis is IP67 rated and weighs 2.8kg. MTBF is rated at 150,000 hours per MIL-HDBK-217F.", "category": "summarization", "max_tokens": 256} -{"text": "The following paragraph contains exactly ONE false statement. Identify it.\n\nThe Great Wall of China stretches over 13,000 miles and was built over many centuries beginning in the 7th century BC. The wall is visible from the International Space Station with the naked eye under favorable conditions. The primary purpose of the wall was military defense against invasions from the north. Construction materials varied by region, including tamped earth, brick, stone, and wood. The wall was designated a UNESCO World Heritage Site in 1987. During the Ming Dynasty (1368-1644), the wall underwent significant reconstruction and expansion. Several sections of the wall are now popular tourist destinations, with the Badaling section near Beijing being the most visited.", "category": "needle_in_haystack", "max_tokens": 256} -{"text": "You are given a nested data structure. Answer questions about it.\n\n```\ncompany:\n name: TechCorp Industries\n founded: 2005\n ceo: Sarah Mitchell\n departments:\n - name: Engineering\n head: James Park\n teams:\n - name: Backend\n lead: Alice Wang\n members: 12\n tech_stack: [Python, Go, PostgreSQL]\n - name: Frontend\n lead: Bob Lee\n members: 8\n tech_stack: [TypeScript, React, Next.js]\n - name: ML/AI\n lead: Carol Zhang\n members: 6\n tech_stack: [Python, PyTorch, MLX]\n - name: Product\n head: Diana Ross\n teams:\n - name: Growth\n lead: Eric Kim\n members: 4\n - name: Platform\n lead: Fiona Chen\n members: 5\n - name: Operations\n head: George Liu\n budget: 2500000\n teams:\n - name: DevOps\n lead: Henry Wu\n members: 3\n tech_stack: [Kubernetes, Terraform, AWS]\n - name: Security\n lead: Irene Tanaka\n members: 2\n```\n\nQuestions:\n1. How many total members are in the Engineering department?\n2. Who is the head of the Product department?\n3. Which team uses PyTorch?\n4. What is the Operations department budget?\n5. How many teams are there across all departments?", "category": "recall", "max_tokens": 512} -{"text": "Read this meeting transcript and extract the action items with owners and deadlines.\n\n[Meeting: Q1 Planning - January 8, 2026]\n\nSarah: Let's review the deliverables. Mike, where are we on the API redesign?\n\nMike: We finished the design doc. Implementation starts next week. I'll have the auth endpoints done by January 22nd and the data endpoints by February 5th.\n\nSarah: Good. Lisa, what about the mobile app update?\n\nLisa: The redesigned onboarding flow is 80% done. I need design assets from Tom's team. Tom, can you get those to me by January 15th?\n\nTom: I'll prioritize it. The full design system update should be done by end of January. I also need to schedule a review with the accessibility consultant - I'll book that for the week of January 20th.\n\nSarah: Perfect. Chris, the infrastructure migration?\n\nChris: Database migration to the new cluster is scheduled for February 1st. I need Mike to freeze schema changes by January 25th. Also, I'll set up the staging environment by January 18th for testing.\n\nSarah: Great. I'll prepare the board presentation by January 30th. Let's reconvene on January 22nd to check progress.\n\nExtract all action items in this format: [Owner] - [Task] - [Deadline]", "category": "recall", "max_tokens": 512} -{"text": "Below is a passage with key numbers embedded throughout. Read it and answer the numerical questions at the end.\n\nThe city of Avalonia was founded in 1847 by 342 settlers who arrived on 7 ships. By 1900, the population had grown to 15,847. The city's main bridge, completed in 1923, spans 1,247 feet across the Emerald River and cost $3.2 million to build (approximately $56.8 million in today's dollars). The bridge has 4 main support pillars, each sunk 89 feet into the riverbed.\n\nThe Avalonia Public Library, established in 1891 with a donation of $175,000 from industrialist Marcus Webb, currently houses 847,000 volumes across 6 floors. The building was renovated in 1967, 2001, and most recently in 2019 at a cost of $12.4 million. The library employs 156 staff members and serves approximately 23,000 active cardholders.\n\nThe city's annual Harvest Festival, first held in 1853, attracts an average of 45,000 visitors over its 3-day duration. The festival features 127 food vendors, 34 live music acts, and generates approximately $2.7 million in revenue for local businesses. The festival's famous pumpkin contest has been won 8 times by the Henderson family.\n\nQuestions:\n1. How many ships brought the original settlers?\n2. What was the bridge construction cost in today's dollars?\n3. How deep are the bridge support pillars?\n4. When was the library's most recent renovation?\n5. How much revenue does the Harvest Festival generate?", "category": "needle_in_haystack", "max_tokens": 512} diff --git a/prompt-suites/reasoning.jsonl b/prompt-suites/reasoning.jsonl deleted file mode 100644 index 31d46fe..0000000 --- a/prompt-suites/reasoning.jsonl +++ /dev/null @@ -1,50 +0,0 @@ -{"text": "What is 47 * 83 + 156 / 12?", "category": "arithmetic", "max_tokens": 256} -{"text": "If I have 3/7 of a pizza and eat 1/3 of what I have, what fraction of the whole pizza did I eat?", "category": "arithmetic", "max_tokens": 256} -{"text": "What is the square root of 2 to 10 decimal places?", "category": "precision", "max_tokens": 256} -{"text": "Calculate: 15! (fifteen factorial)", "category": "arithmetic", "max_tokens": 256} -{"text": "What is 2^64?", "category": "large_numbers", "max_tokens": 256} -{"text": "A train leaves Chicago at 2:15 PM traveling at 80 mph. Another leaves Detroit (280 miles away) at 2:45 PM traveling at 95 mph toward Chicago. At what time do they meet?", "category": "word_problem", "max_tokens": 512} -{"text": "Is 7919 a prime number? Show your reasoning.", "category": "logic", "max_tokens": 512} -{"text": "If all bloops are razzles, and all razzles are lazzles, are all bloops definitely lazzles?", "category": "logic", "max_tokens": 256} -{"text": "Three people check into a hotel room that costs $30. They each pay $10. Later the manager realizes the room is only $25 and gives $5 to the bellboy to return. The bellboy keeps $2 and gives $1 back to each person. Now each person paid $9 (total $27), plus the bellboy has $2 = $29. Where is the missing dollar?", "category": "logic", "max_tokens": 512} -{"text": "Convert the hexadecimal number 0xDEADBEEF to decimal.", "category": "arithmetic", "max_tokens": 256} -{"text": "What is the sum of all integers from 1 to 1000?", "category": "arithmetic", "max_tokens": 256} -{"text": "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?", "category": "logic", "max_tokens": 256} -{"text": "If you flip a fair coin 10 times and get heads every time, what is the probability the next flip is heads?", "category": "logic", "max_tokens": 256} -{"text": "How many times does the digit 7 appear in all numbers from 1 to 1000?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "A farmer has 17 sheep. All but 9 die. How many sheep are left?", "category": "logic", "max_tokens": 256} -{"text": "What is 0.1 + 0.2? Is it exactly 0.3? Explain.", "category": "precision", "max_tokens": 256} -{"text": "Solve step by step: 8 ÷ 2(2+2) = ?", "category": "arithmetic", "max_tokens": 512} -{"text": "I have 12 coins. One is counterfeit and lighter than the rest. Using a balance scale, what is the minimum number of weighings needed to find the counterfeit coin?", "category": "logic", "max_tokens": 512} -{"text": "What is the derivative of x^x with respect to x?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "A snail climbs 3 feet up a well each day but slides back 2 feet each night. If the well is 30 feet deep, how many days does it take to escape?", "category": "word_problem", "max_tokens": 512} -{"text": "What is the 50th term of the Fibonacci sequence?", "category": "arithmetic", "max_tokens": 256} -{"text": "If it takes 5 machines 5 minutes to make 5 widgets, how long would it take 100 machines to make 100 widgets?", "category": "logic", "max_tokens": 256} -{"text": "Evaluate the integral of 1/(1+x^2) from 0 to infinity.", "category": "chain_of_thought", "max_tokens": 512} -{"text": "There are 100 lockers in a row, all closed. 100 students walk by. Student 1 opens every locker. Student 2 toggles every 2nd locker. Student 3 toggles every 3rd, and so on. After all 100 students pass, which lockers are open?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "Express 0.999... (repeating) as a fraction. Is it equal to 1?", "category": "logic", "max_tokens": 256} -{"text": "What is 999999 * 999999?", "category": "arithmetic", "max_tokens": 256} -{"text": "A car travels the first half of a distance at 40 mph and the second half at 60 mph. What is the average speed for the whole trip?", "category": "word_problem", "max_tokens": 256} -{"text": "In a room of 23 people, what is the probability that at least two share a birthday?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "What is the largest prime number less than 10000?", "category": "arithmetic", "max_tokens": 256} -{"text": "You have two ropes. Each takes exactly 1 hour to burn, but they burn non-uniformly. How can you measure exactly 45 minutes?", "category": "logic", "max_tokens": 512} -{"text": "Simplify: (x^2 - 9) / (x^2 - 6x + 9)", "category": "chain_of_thought", "max_tokens": 256} -{"text": "How many zeros are at the end of 100! (100 factorial)?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "Three friends split a dinner bill. The bill is $47.50 plus 20% tip, split evenly. How much does each person pay?", "category": "word_problem", "max_tokens": 256} -{"text": "What is the angle between the hour and minute hands of a clock at 3:15?", "category": "word_problem", "max_tokens": 256} -{"text": "Is the statement 'this sentence is false' true or false? Explain the paradox.", "category": "logic", "max_tokens": 512} -{"text": "Convert 72 degrees Fahrenheit to Celsius, then to Kelvin.", "category": "arithmetic", "max_tokens": 256} -{"text": "A lily pad doubles in size every day. If it takes 48 days to cover a pond, how many days does it take to cover half the pond?", "category": "logic", "max_tokens": 256} -{"text": "Calculate the determinant of the matrix [[3,1,4],[1,5,9],[2,6,5]].", "category": "chain_of_thought", "max_tokens": 512} -{"text": "What is the probability of getting exactly 3 heads in 5 fair coin flips?", "category": "chain_of_thought", "max_tokens": 256} -{"text": "If log_2(x) = 5, what is log_4(x)?", "category": "arithmetic", "max_tokens": 256} -{"text": "You have a 3-gallon jug and a 5-gallon jug. How do you measure exactly 4 gallons?", "category": "logic", "max_tokens": 512} -{"text": "What is the GCD of 1071 and 462? Show your work using the Euclidean algorithm.", "category": "chain_of_thought", "max_tokens": 512} -{"text": "A store marks up items 40% over cost, then offers a 25% discount. What is the net percentage change from the original cost?", "category": "word_problem", "max_tokens": 256} -{"text": "How many distinct ways can you arrange the letters in MISSISSIPPI?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "Solve: |2x - 5| < 3. Express the solution as an interval.", "category": "arithmetic", "max_tokens": 256} -{"text": "If you pick 2 cards from a standard deck without replacement, what is the probability both are aces?", "category": "chain_of_thought", "max_tokens": 256} -{"text": "A projectile is launched at 45 degrees with initial velocity 20 m/s. Ignoring air resistance, what is the maximum height? (g = 9.8 m/s²)", "category": "word_problem", "max_tokens": 512} -{"text": "What is the remainder when 2^1000 is divided by 7?", "category": "chain_of_thought", "max_tokens": 512} -{"text": "Three cards are placed face down: a king, a queen, and a jack. You pick one. The dealer reveals one of the remaining cards is a jack. Should you switch your choice?", "category": "logic", "max_tokens": 512} -{"text": "Compute the cross product of vectors [1,2,3] and [4,5,6].", "category": "arithmetic", "max_tokens": 256} diff --git a/pyproject.toml b/pyproject.toml index 6ac7aa5..aac2ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,12 +61,15 @@ dev = [ [tool.ruff] target-version = "py313" -line-length = 100 +line-length = 120 src = ["src"] [tool.ruff.lint] select = ["E", "F", "I", "N", "W", "UP", "B", "SIM"] +[tool.ruff.lint.per-file-ignores] +"html.py" = ["E501"] + [tool.mypy] python_version = "3.13" strict = true diff --git a/src/infer_check/analysis/__init__.py b/src/infer_check/analysis/__init__.py index 669f392..f4f30a2 100644 --- a/src/infer_check/analysis/__init__.py +++ b/src/infer_check/analysis/__init__.py @@ -1,3 +1,19 @@ # Analysis modules for infer-check. # These modules operate on InferenceResult and ComparisonResult objects # and never call backends directly. + +from infer_check.analysis.answer_extract import ( + ExtractedAnswer, + FlipDetail, + answers_match, + compute_flip_rate, + extract_answer, +) + +__all__ = [ + "ExtractedAnswer", + "FlipDetail", + "answers_match", + "compute_flip_rate", + "extract_answer", +] diff --git a/src/infer_check/analysis/answer_extract.py b/src/infer_check/analysis/answer_extract.py new file mode 100644 index 0000000..b23b227 --- /dev/null +++ b/src/infer_check/analysis/answer_extract.py @@ -0,0 +1,395 @@ +"""Answer extraction and flip-rate computation. + +Extracts the "functional answer" from LLM outputs so that two responses +can be compared semantically rather than by raw text similarity. A +*flip* occurs when two models give substantively different answers to the +same prompt — not merely different wording. + +Extraction strategies are selected by prompt category: + + - **numeric** (arithmetic, precision, large_numbers, floating_point, + underflow, formatting, word_problem, multi_digit_arithmetic, + precision_numerics, large_number_reasoning, algebraic_reasoning, + logical_puzzle): extract the last number/expression. + - **boolean** (logic, edge_case): extract yes / no / true / false. + - **code** (python, debugging, completion, precise_syntax, + code_translation): extract fenced code blocks + and compare after whitespace normalisation. + - **json** (json): parse and compare structurally. + - **fallback**: character-level similarity via ``difflib``. +""" + +from __future__ import annotations + +import difflib +import json as json_mod +import re +from collections.abc import Sequence +from dataclasses import dataclass + +__all__ = [ + "ExtractedAnswer", + "extract_answer", + "answers_match", + "compute_flip_rate", +] + + +# ── Categories → extraction strategy ──────────────────────────────── + +_NUMERIC_CATEGORIES = frozenset( + { + "arithmetic", + "precision", + "large_numbers", + "floating_point", + "underflow", + "formatting", + "word_problem", + "multi_digit_arithmetic", + "precision_numerics", + "large_number_reasoning", + "algebraic_reasoning", + "logical_puzzle", + } +) + +_BOOLEAN_CATEGORIES = frozenset( + { + "logic", + "edge_case", + } +) + +_CODE_CATEGORIES = frozenset( + { + "python", + "debugging", + "completion", + "precise_syntax", + "code_translation", + } +) + +_JSON_CATEGORIES = frozenset( + { + "json", + } +) + + +# ── Regex patterns ────────────────────────────────────────────────── + +# Matches integers, decimals, scientific notation, and comma-separated +# numbers like 1,234,567.89. Also matches negative numbers. +_NUMBER_RE = re.compile( + r"-?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d+)?(?:[eE][+-]?\d+)?", +) + +# Fenced code block: ```...``` (with optional language tag). +_CODE_BLOCK_RE = re.compile( + r"```(?:\w+)?\s*\n(.*?)```", + re.DOTALL, +) + +# Boolean-ish final answer patterns. +_BOOLEAN_RE = re.compile( + r"\b(yes|no|true|false|correct|incorrect|definitely|not necessarily)\b", + re.IGNORECASE, +) + + +@dataclass(frozen=True, slots=True) +class ExtractedAnswer: + """The extracted functional answer from an LLM response.""" + + strategy: str # "numeric", "boolean", "code", "json", "raw" + value: str # normalised answer string for comparison + raw_text: str # original full response text + confidence: float # 0–1, how confident we are in extraction + + +# ── Extraction helpers ────────────────────────────────────────────── + + +def _extract_numeric(text: str) -> ExtractedAnswer: + """Pull the last number from a math/numeric response.""" + matches = _NUMBER_RE.findall(text) + if not matches: + return ExtractedAnswer( + strategy="numeric", + value="", + raw_text=text, + confidence=0.0, + ) + # Take the last number — models tend to state the final answer last. + last = matches[-1].replace(",", "") + return ExtractedAnswer( + strategy="numeric", + value=last, + raw_text=text, + confidence=0.9, + ) + + +def _extract_boolean(text: str) -> ExtractedAnswer: + """Pull the boolean conclusion from a logic response. + + Scans for yes/no/true/false keywords. When multiple are found, + the *last* one wins (models often hedge before concluding). + Negation context is checked: "not correct" → no, "not true" → no. + """ + matches: list[tuple[int, str]] = [(m.start(), m.group(0).lower()) for m in _BOOLEAN_RE.finditer(text)] + if not matches: + return ExtractedAnswer( + strategy="boolean", + value="", + raw_text=text, + confidence=0.0, + ) + + _pos = {"yes", "correct", "definitely", "true"} + _neg = {"no", "incorrect", "not necessarily", "false"} + + pos, raw = matches[-1] + text_lower = text.lower() + + # Check for "not" within 5 chars before a positive keyword. + negated = False + if raw in _pos: + prefix = text_lower[max(0, pos - 5) : pos] + if "not" in prefix.split(): + negated = True + + if raw in _neg or negated: + normalised = "no" + elif raw in _pos: + normalised = "yes" + else: + normalised = raw + + return ExtractedAnswer( + strategy="boolean", + value=normalised, + raw_text=text, + confidence=0.85, + ) + + +def _extract_code(text: str) -> ExtractedAnswer: + """Extract fenced code blocks and normalise whitespace.""" + blocks = _CODE_BLOCK_RE.findall(text) + if not blocks: + # No fenced block — treat entire response as code (common for + # models that don't fence). Strip leading prose heuristically. + lines = text.strip().splitlines() + code_lines = [] + in_code = False + for line in lines: + stripped = line.strip() + # Heuristic: code lines start with def, class, import, + # return, if, for, while, #, or are indented. + if ( + in_code + or stripped.startswith( + ( + "def ", + "class ", + "import ", + "from ", + "return ", + "if ", + "for ", + "while ", + "#", + " ", + "\t", + ) + ) + or stripped == "" + ): + code_lines.append(line) + in_code = True + code = "\n".join(code_lines).strip() + confidence = 0.5 if code else 0.0 + else: + code = "\n\n".join(b.strip() for b in blocks) + confidence = 0.9 + + # Normalise: collapse whitespace runs, strip trailing ws per line. + normalised = "\n".join(line.rstrip() for line in code.splitlines()).strip() + return ExtractedAnswer( + strategy="code", + value=normalised, + raw_text=text, + confidence=confidence, + ) + + +def _extract_json(text: str) -> ExtractedAnswer: + """Extract and canonicalise JSON from the response.""" + # Try to find a JSON block in fences first. + blocks = _CODE_BLOCK_RE.findall(text) + candidates = blocks if blocks else [text] + + for candidate in candidates: + candidate = candidate.strip() + # Find the first { or [ and try to parse from there. + for start_char in ("{", "["): + idx = candidate.find(start_char) + if idx == -1: + continue + try: + parsed = json_mod.loads(candidate[idx:]) + canonical = json_mod.dumps( + parsed, + sort_keys=True, + separators=(",", ":"), + ) + return ExtractedAnswer( + strategy="json", + value=canonical, + raw_text=text, + confidence=0.95, + ) + except json_mod.JSONDecodeError: + continue + + return ExtractedAnswer( + strategy="json", + value="", + raw_text=text, + confidence=0.0, + ) + + +def _extract_raw(text: str) -> ExtractedAnswer: + """Fallback: use the full text, lightly normalised.""" + normalised = " ".join(text.lower().split()) + return ExtractedAnswer( + strategy="raw", + value=normalised, + raw_text=text, + confidence=0.3, + ) + + +# ── Public API ────────────────────────────────────────────────────── + + +def extract_answer(text: str, category: str = "general") -> ExtractedAnswer: + """Extract the functional answer from an LLM response. + + Selects an extraction strategy based on the prompt category. + + Args: + text: The full LLM response text. + category: The prompt category (from ``Prompt.category``). + + Returns: + An ``ExtractedAnswer`` with the normalised value and metadata. + """ + cat = category.lower() + if cat in _NUMERIC_CATEGORIES: + return _extract_numeric(text) + if cat in _BOOLEAN_CATEGORIES: + return _extract_boolean(text) + if cat in _CODE_CATEGORIES: + return _extract_code(text) + if cat in _JSON_CATEGORIES: + return _extract_json(text) + return _extract_raw(text) + + +def answers_match( + a: ExtractedAnswer, + b: ExtractedAnswer, + *, + similarity_threshold: float = 0.85, +) -> bool: + """Determine whether two extracted answers are functionally equivalent. + + For numeric, boolean, and json strategies the comparison is exact + (after normalisation). For code and raw strategies, a similarity + threshold is used. + + Args: + a: First extracted answer. + b: Second extracted answer. + similarity_threshold: Minimum ``SequenceMatcher`` ratio for + code/raw comparisons to be considered a match. + + Returns: + ``True`` if the answers are functionally equivalent. + """ + # If either extraction failed, fall back to raw similarity. + if not a.value or not b.value: + ratio = difflib.SequenceMatcher(None, a.value or "", b.value or "").ratio() + return ratio >= similarity_threshold + + strategy = a.strategy # both should share strategy if same prompt + + if strategy in ("numeric", "boolean", "json"): + return a.value == b.value + + # Code and raw: use sequence similarity. + ratio = difflib.SequenceMatcher(None, a.value, b.value).ratio() + return ratio >= similarity_threshold + + +@dataclass(frozen=True, slots=True) +class FlipDetail: + """Per-prompt flip analysis result.""" + + prompt_id: str + category: str + flipped: bool + answer_a: ExtractedAnswer + answer_b: ExtractedAnswer + + +def compute_flip_rate( + pairs: Sequence[tuple[str, str, str, str]], + *, + similarity_threshold: float = 0.85, +) -> tuple[float, list[FlipDetail]]: + """Compute the flip rate across a set of response pairs. + + Args: + pairs: Sequence of ``(prompt_id, category, text_a, text_b)`` + tuples. + similarity_threshold: Passed through to ``answers_match`` for + code/raw comparisons. + + Returns: + A tuple of ``(flip_rate, details)`` where ``flip_rate`` is in + [0, 1] and ``details`` is a list of per-prompt ``FlipDetail`` + objects. + """ + if not pairs: + return 0.0, [] + + details: list[FlipDetail] = [] + flip_count = 0 + + for prompt_id, category, text_a, text_b in pairs: + ans_a = extract_answer(text_a, category) + ans_b = extract_answer(text_b, category) + flipped = not answers_match( + ans_a, + ans_b, + similarity_threshold=similarity_threshold, + ) + if flipped: + flip_count += 1 + details.append( + FlipDetail( + prompt_id=prompt_id, + category=category, + flipped=flipped, + answer_a=ans_a, + answer_b=ans_b, + ) + ) + + return flip_count / len(pairs), details diff --git a/src/infer_check/backends/base.py b/src/infer_check/backends/base.py index 641e8ee..f4b062b 100644 --- a/src/infer_check/backends/base.py +++ b/src/infer_check/backends/base.py @@ -5,7 +5,7 @@ from infer_check.types import InferenceResult, Prompt -__all__ = ["BackendAdapter", "BackendConfig", "get_backend"] +__all__ = ["BackendAdapter", "BackendConfig", "get_backend", "get_backend_for_model"] class BackendAdapter(Protocol): @@ -73,8 +73,7 @@ def get_backend(config: BackendConfig) -> BackendAdapter: if not config.base_url: raise ValueError( - "openai-compat backend requires --base-url. " - "Example: --base-url http://localhost:11434/v1 (Ollama)" + "openai-compat backend requires --base-url. Example: --base-url http://localhost:11434/v1 (Ollama)" ) return OpenAICompatBackend( base_url=config.base_url, @@ -85,3 +84,28 @@ def get_backend(config: BackendConfig) -> BackendAdapter: else: supported = ", ".join(["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"]) raise ValueError(f"Unknown backend type: '{config.backend_type}'. Supported: {supported}") + + +def get_backend_for_model( + model_str: str, + backend_type: str | None = None, + base_url: str | None = None, + quantization: str | None = None, +) -> BackendAdapter: + """Resolve model string to a backend and instantiate it. + + If backend_type is provided, it forces that backend. Otherwise, it resolves + based on the model string. + """ + from infer_check.resolve import resolve_model + + # Always normalize the model string first to ensure consistent model_id/base_url + resolved = resolve_model(model_str, base_url=base_url) + config = BackendConfig( + backend_type=backend_type or resolved.backend, # type: ignore + model_id=resolved.model_id, + base_url=resolved.base_url, + quantization=quantization or resolved.label, + ) + + return get_backend(config) diff --git a/src/infer_check/backends/llama_cpp.py b/src/infer_check/backends/llama_cpp.py index f9d48a0..6635fc6 100644 --- a/src/infer_check/backends/llama_cpp.py +++ b/src/infer_check/backends/llama_cpp.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import math import time import httpx @@ -36,6 +37,7 @@ async def generate(self, prompt: Prompt) -> InferenceResult: "prompt": prompt.text, "n_predict": prompt.max_tokens, "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0, + "n_probs": 10, # Request top 10 probabilities for KL divergence } start = time.perf_counter() @@ -72,17 +74,38 @@ async def generate(self, prompt: Prompt) -> InferenceResult: # Extract per-token data ------------------------------------------ tokens: list[str] = [] logprobs: list[float] | None = None + distributions: list[list[float]] | None = None + distribution_metadata: list[dict[str, int | str]] | None = None completion_probs = data.get("completion_probabilities") if completion_probs: logprobs = [] + distributions = [] + distribution_metadata = [] for entry in completion_probs: tok_str = entry.get("content", "") tokens.append(tok_str) - # Top prob entry (index 0) contains the chosen token logprob. + # Top prob entry (index 0) contains the chosen token linear prob. probs = entry.get("probs", []) if probs: - logprobs.append(float(probs[0].get("prob", 0.0))) + # llama-server returns linear probabilities (0..1). + # We convert to log-probabilities (log-space) to match the rest of the codebase. + # Epsilon 1e-10 matches kl_divergence epsilon in the analyzer. + epsilon = 1e-10 + logprobs.append(math.log(max(float(probs[0].get("prob", 0.0)), epsilon))) + + dist_logprobs = [] + for p in probs: + p_val = max(float(p.get("prob", 0.0)), epsilon) + dist_logprobs.append(math.log(p_val)) + distributions.append(dist_logprobs) + + # Store token IDs to allow alignment if needed. + dist_meta: dict[str, int | str] = {} + for i, p in enumerate(probs): + if "id" in p: + dist_meta[f"id_{i}"] = int(p["id"]) + distribution_metadata.append(dist_meta) else: tokens = content.split() @@ -98,6 +121,8 @@ async def generate(self, prompt: Prompt) -> InferenceResult: model_id=data.get("model", "unknown"), tokens=tokens, logprobs=logprobs, + distributions=distributions, + distribution_metadata=distribution_metadata, text=content, latency_ms=elapsed_s * 1000, tokens_per_second=tps, diff --git a/src/infer_check/backends/mlx_lm.py b/src/infer_check/backends/mlx_lm.py index 6d03202..6996f04 100644 --- a/src/infer_check/backends/mlx_lm.py +++ b/src/infer_check/backends/mlx_lm.py @@ -4,7 +4,7 @@ import gc import time -from typing import Any +from typing import Any, cast from infer_check.types import InferenceResult, Prompt @@ -38,6 +38,9 @@ async def generate(self, prompt: Prompt) -> InferenceResult: Uses ``mlx_lm.generate_step`` when available so that per-token logprobs can be captured. Falls back to the simpler ``mlx_lm.generate`` otherwise. + + The actual computation is synchronous (MLX is single-threaded), + but the method is async to satisfy the ``BackendAdapter`` protocol. """ self._ensure_loaded() @@ -51,16 +54,16 @@ async def generate(self, prompt: Prompt) -> InferenceResult: return self._generate_simple(prompt) except Exception as inner: raise RuntimeError( - f"MLX generation failed for prompt '{prompt.text[:80]}...'\n" - f"Model: {self._model_id}\n" - f"Error: {inner}" + f"MLX generation failed for prompt '{prompt.text[:80]}...'\nModel: {self._model_id}\nError: {inner}" ) from inner async def generate_batch(self, prompts: list[Prompt]) -> list[InferenceResult]: - """Generate inference results for a batch of prompts.""" - import asyncio + """Generate inference results for a batch of prompts. - return list(await asyncio.gather(*(self.generate(p) for p in prompts))) + MLX is single-threaded so we run sequentially rather than + using ``asyncio.gather`` which would not yield parallelism. + """ + return [await self.generate(p) for p in prompts] async def health_check(self) -> bool: """Load the model and generate a single token.""" @@ -96,17 +99,14 @@ def _ensure_loaded(self) -> None: try: from mlx_lm import load except ImportError: - raise RuntimeError( - "mlx-lm not installed. Install with: pip install infer-check[mlx]" - ) from None + raise RuntimeError("mlx-lm not installed. Install with: pip install infer-check[mlx]") from None from pathlib import Path model_path = Path(self._model_id).expanduser() if model_path.is_absolute() and not model_path.exists(): raise FileNotFoundError( - f"Model path does not exist: {model_path}\n" - f"Check the path or use a HuggingFace repo ID instead." + f"Model path does not exist: {model_path}\nCheck the path or use a HuggingFace repo ID instead." ) repo_or_path = str(model_path) if model_path.exists() else self._model_id @@ -137,16 +137,9 @@ def _format_prompt(self, text: str) -> str: Raw prompts sent to Instruct models produce undefined behavior that varies across quantization levels, making comparisons meaningless. """ - if ( - hasattr(self._tokenizer, "apply_chat_template") - and self._tokenizer.chat_template is not None - ): + if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template is not None: messages = [{"role": "user", "content": text}] - return str( - self._tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - ) + return str(self._tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)) return text def _generate_simple(self, prompt: Prompt) -> InferenceResult: @@ -193,12 +186,17 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult: formatted = self._format_prompt(prompt.text) input_ids = self._tokenizer.encode(formatted, return_tensors="mlx") + # Configurable top-K to avoid memory explosion. Default to 10. + top_k = prompt.metadata.get("top_k_logprobs", 10) if prompt.metadata else 10 + tokens: list[str] = [] logprobs: list[float] = [] + distributions: list[list[float]] = [] + distribution_metadata: list[dict[str, int | str]] = [] start = time.perf_counter() - for step_idx, (token, logprob_val) in enumerate( + for step_idx, (token, logprob_dist) in enumerate( generate_step( prompt=input_ids, model=self._model, @@ -208,14 +206,49 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult: if step_idx >= prompt.max_tokens: break - token_str = self._tokenizer.decode([token.item()]) + # logprob_dist is an mlx array of full-vocab logprobs. + # We only keep top-K to save memory. + if top_k > 0: + # mx.argpartition is not always available or might be slow for small K. + # Since we need to move to CPU anyway for serialization, we can do it there. + # But to avoid huge tolist(), we can use mx.topk if available or just move a bit. + + # Clamp K to the vocabulary size to avoid out-of-bounds issues. + vocab_size = int(logprob_dist.shape[0]) + if vocab_size <= 0: + # Nothing to record for this step. + continue + effective_top_k = int(top_k) + if effective_top_k < 1: + # Should not happen due to the outer condition, but guard defensively. + continue + if effective_top_k > vocab_size: + effective_top_k = vocab_size + + # Get top-K indices and values + top_k_indices = mx.argpartition(-logprob_dist, effective_top_k - 1)[:effective_top_k] + top_k_values = logprob_dist[top_k_indices] + + # Sort them for consistency + sort_idx = mx.argsort(-top_k_values) + top_k_indices = top_k_indices[sort_idx] + top_k_values = top_k_values[sort_idx] + + dist_list = cast(list[float], top_k_values.tolist()) + dist_indices = cast(list[int], top_k_indices.tolist()) + + distributions.append(dist_list) + meta: dict[str, int | str] = {} + for i, idx in enumerate(dist_indices): + meta[f"id_{i}"] = int(idx) + distribution_metadata.append(meta) + + token_id = int(token.item()) + token_str = self._tokenizer.decode([token_id]) tokens.append(token_str) - # logprob_val may be an mx.array scalar or a float - if hasattr(logprob_val, "item"): - logprobs.append(float(logprob_val.item())) - else: - logprobs.append(float(logprob_val)) + # The logprob of the chosen token (from the full distribution) + logprobs.append(float(logprob_dist[token_id])) elapsed_s = time.perf_counter() - start text = "".join(tokens) @@ -231,6 +264,8 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult: quantization=self._quantization, tokens=tokens, logprobs=logprobs if logprobs else None, + distributions=distributions if distributions else None, + distribution_metadata=distribution_metadata if distribution_metadata else None, text=text, latency_ms=elapsed_s * 1000, tokens_per_second=tps, diff --git a/src/infer_check/backends/openai_compat.py b/src/infer_check/backends/openai_compat.py index 87ed4f4..32dd865 100644 --- a/src/infer_check/backends/openai_compat.py +++ b/src/infer_check/backends/openai_compat.py @@ -82,9 +82,7 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult: "Ensure the server is running and the base_url is correct." ) from exc except httpx.TimeoutException as exc: - raise RuntimeError( - f"Request to {self._base_url}/v1/chat/completions timed out after 120s." - ) from exc + raise RuntimeError(f"Request to {self._base_url}/v1/chat/completions timed out after 120s.") from exc except httpx.HTTPStatusError as exc: status = exc.response.status_code body = exc.response.text[:500] @@ -155,8 +153,7 @@ async def _generate_completions(self, prompt: Prompt) -> InferenceResult: ) from exc elif status == 401 or status == 403: raise RuntimeError( - f"Authentication failed at {self._base_url} (HTTP {status}). " - f"Check your API key." + f"Authentication failed at {self._base_url} (HTTP {status}). Check your API key." ) from exc else: raise RuntimeError(f"Server returned HTTP {status}: {body}") from exc @@ -177,14 +174,42 @@ async def _generate_completions(self, prompt: Prompt) -> InferenceResult: # Parse logprobs --------------------------------------------------- tokens: list[str] = [] logprobs_list: list[float] | None = None + distributions: list[list[float]] | None = None + distribution_metadata: list[dict[str, int | str]] | None = None lp_data = choice.get("logprobs") if lp_data and lp_data.get("tokens"): tokens = lp_data["tokens"] raw_logprobs = lp_data.get("token_logprobs", []) - logprobs_list = [ - float(v) if v is not None and not math.isnan(v) else 0.0 for v in raw_logprobs - ] + logprobs_list = [float(v) if v is not None and not math.isnan(v) else 0.0 for v in raw_logprobs] + + top_logprobs = lp_data.get("top_logprobs") + if top_logprobs: + distributions = [] + distribution_metadata = [] + for entry in top_logprobs: + # entry is a dict of token: logprob + # Sort by token string to ensure deterministic order + if not entry: + distributions.append([]) + distribution_metadata.append({}) + continue + sorted_items = sorted(entry.items()) + cleaned_items: list[tuple[str, float]] = [] + for tok, v in sorted_items: + # Mirror token_logprobs sanitization: treat None/NaN/invalid as 0.0 + try: + fv = float(v) if v is not None else float("nan") + except (TypeError, ValueError): + fv = float("nan") + if math.isnan(fv): + fv = 0.0 + cleaned_items.append((tok, fv)) + distributions.append([fv for _, fv in cleaned_items]) + meta: dict[str, int | str] = {} + for i, (tok, _) in enumerate(cleaned_items): + meta[f"id_{i}"] = tok + distribution_metadata.append(meta) else: tokens = text.split() @@ -199,6 +224,8 @@ async def _generate_completions(self, prompt: Prompt) -> InferenceResult: model_id=self._model_id, tokens=tokens, logprobs=logprobs_list, + distributions=distributions, + distribution_metadata=distribution_metadata, text=text, latency_ms=elapsed_s * 1000, tokens_per_second=tps, diff --git a/src/infer_check/cli.py b/src/infer_check/cli.py index 01aa698..2617a21 100644 --- a/src/infer_check/cli.py +++ b/src/infer_check/cli.py @@ -45,7 +45,7 @@ def main() -> None: "4bit=mlx-community/Llama-3.1-8B-Instruct-4bit'" ), ) -@click.option("--backend", default="mlx-lm", show_default=True, help="Backend type.") +@click.option("--backend", default=None, help="Backend type (auto-detected if omitted).") @click.option( "--prompts", required=True, @@ -66,7 +66,7 @@ def main() -> None: @click.option("--base-url", default=None, help="Base URL for HTTP backends.") def sweep( models: str, - backend: str, + backend: str | None, prompts: str, output: Path, baseline: str | None, @@ -83,10 +83,9 @@ def sweep( --models "bf16=mlx-community/Llama-3.1-8B-Instruct-bf16, 4bit=mlx-community/Llama-3.1-8B-Instruct-4bit, 3bit=mlx-community/Llama-3.1-8B-Instruct-3bit" \\ - --backend mlx-lm \\ - --prompts ./prompt-suites/reasoning.jsonl + --prompts reasoning """ - from infer_check.backends.base import BackendConfig, get_backend + from infer_check.backends.base import get_backend_for_model from infer_check.runner import TestRunner from infer_check.suites.loader import load_suite @@ -113,10 +112,7 @@ def sweep( console.print(f"[red]Baseline '{baseline_label}' not found in model map.[/red]") raise SystemExit(1) - console.print( - f"[bold cyan]sweep[/bold cyan] backend={backend} " - f"baseline={baseline_label} models={quant_levels}" - ) + console.print(f"[bold cyan]sweep[/bold cyan] baseline={baseline_label} models={quant_levels}") for label, path in model_map.items(): tag = " (baseline)" if label == baseline_label else "" console.print(f" {label}: {path}{tag}") @@ -126,13 +122,12 @@ def sweep( # Build a separate backend for each model backend_map: dict[str, Any] = {} for label, model_path in model_map.items(): - config = BackendConfig( - backend_type=backend, # type: ignore[arg-type] - model_id=model_path, - quantization=label, + backend_map[label] = get_backend_for_model( + model_str=model_path, + backend_type=backend, base_url=base_url, + quantization=label, ) - backend_map[label] = get_backend(config) runner = TestRunner() result = asyncio.run( @@ -208,6 +203,246 @@ def sweep( console.print(table) +# --------------------------------------------------------------------------- +# compare +# --------------------------------------------------------------------------- + + +@main.command() +@click.argument("model_a") +@click.argument("model_b") +@click.option( + "--prompts", + default="adversarial-numerics", + show_default=True, + help="Bundled suite name (e.g. 'reasoning') or path to a .jsonl file.", +) +@click.option( + "--output", + default="./results/compare/", + show_default=True, + type=click.Path(path_type=Path), + help="Output directory.", +) +@click.option( + "--base-url", + default=None, + help=("Base URL override for HTTP backends. Applied to both models unless they resolve to mlx-lm."), +) +@click.option( + "--label-a", + default=None, + help="Custom label for model A (defaults to auto-derived short name).", +) +@click.option( + "--label-b", + default=None, + help="Custom label for model B (defaults to auto-derived short name).", +) +@click.option( + "--report/--no-report", + default=True, + show_default=True, + help="Generate an HTML comparison report after the run.", +) +def compare( + model_a: str, + model_b: str, + prompts: str, + output: Path, + base_url: str | None, + label_a: str | None, + label_b: str | None, + report: bool, +) -> None: + """Compare two quantizations of the same model. + + MODEL_A and MODEL_B are model specs — HuggingFace repos, Ollama tags, + or local GGUF paths. The backend is auto-detected from the identifier, + or you can use an explicit prefix (ollama:, mlx:, gguf:, vllm-mlx:). + + \b + Examples: + # Two MLX quants + infer-check compare \\ + mlx-community/Llama-3.1-8B-Instruct-4bit \\ + mlx-community/Llama-3.1-8B-Instruct-8bit + + # MLX native vs Ollama GGUF + infer-check compare \\ + mlx-community/Llama-3.1-8B-Instruct-4bit \\ + ollama:llama3.1:8b-instruct-q4_K_M + + # Bartowski GGUF vs Unsloth GGUF (both via Ollama) + infer-check compare \\ + ollama:bartowski/Llama-3.1-8B-Instruct-GGUF \\ + ollama:unsloth/Llama-3.1-8B-Instruct-GGUF + """ + # ── Resolve both model specs ───────────────────────────────────── + from infer_check.resolve import resolve_model + from infer_check.runner import TestRunner + from infer_check.suites.loader import load_suite + + resolved_a = resolve_model(model_a, base_url=base_url, label=label_a) + resolved_b = resolve_model(model_b, base_url=base_url, label=label_b) + + console.print( + f"[bold cyan]compare[/bold cyan] " + f"A={resolved_a.label} ({resolved_a.backend}) " + f"vs B={resolved_b.label} ({resolved_b.backend})" + ) + + prompt_list = load_suite(_resolve_prompts(prompts)) + console.print(f" prompts: {len(prompt_list)} from '{prompts}'") + + # ── Build backends ─────────────────────────────────────────────── + from infer_check.backends.base import BackendConfig, get_backend + + config_a = BackendConfig( + backend_type=resolved_a.backend, + model_id=resolved_a.model_id, + quantization=resolved_a.label, + base_url=resolved_a.base_url, + extra={"chat": False}, + ) + config_b = BackendConfig( + backend_type=resolved_b.backend, + model_id=resolved_b.model_id, + quantization=resolved_b.label, + base_url=resolved_b.base_url, + extra={"chat": False}, + ) + backend_a = get_backend(config_a) + backend_b = get_backend(config_b) + + # ── Run comparison ─────────────────────────────────────────────── + runner = TestRunner() + compare_result = asyncio.run( + runner.compare( + backend_a=backend_a, + backend_b=backend_b, + prompts=prompt_list, + label_a=resolved_a.label, + label_b=resolved_b.label, + ) + ) + + # ── Persist results ────────────────────────────────────────────── + output.mkdir(parents=True, exist_ok=True) + + from infer_check.utils import sanitize_filename + + safe_a = sanitize_filename(resolved_a.label) + safe_b = sanitize_filename(resolved_b.label) + out_path = output / f"compare_{safe_a}_vs_{safe_b}.json" + compare_result.save(out_path) + console.print(f"[green]Results saved to {out_path}[/green]") + + # ── Summary table ──────────────────────────────────────────────── + table = Table( + title=f"Compare: {resolved_a.label} vs {resolved_b.label}", + show_header=True, + header_style="bold magenta", + ) + table.add_column("metric", style="cyan") + table.add_column("value", justify="right") + + n = len(compare_result.comparisons) + severities = {"identical": 0, "minor": 0, "moderate": 0, "severe": 0} + for c in compare_result.comparisons: + sev = c.metadata.get("severity", "unknown") if hasattr(c, "metadata") else "unknown" + if sev in severities: + severities[sev] += 1 + + table.add_row("prompts", str(n)) + table.add_row( + "flip rate", + f"[{'red' if compare_result.flip_rate > 0.1 else 'green'}]{compare_result.flip_rate:.1%}[/]", + ) + if compare_result.mean_kl_divergence is not None: + table.add_row("mean KL divergence", f"{compare_result.mean_kl_divergence:.6f}") + table.add_row("mean text similarity", f"{compare_result.mean_text_similarity:.4f}") + table.add_row( + "identical / minor / moderate / severe", + f"{severities['identical']} / {severities['minor']} / " + f"{severities['moderate']} / [red]{severities['severe']}[/red]", + ) + + console.print(table) + + # ── Per-category breakdown ─────────────────────────────────────── + if compare_result.per_category_stats: + cat_table = Table( + title="Per-Category Breakdown", + show_header=True, + header_style="bold magenta", + ) + cat_table.add_column("category", style="cyan") + cat_table.add_column("prompts", justify="right") + cat_table.add_column("flip rate", justify="right") + cat_table.add_column("mean similarity", justify="right") + + for cat, stats in sorted(compare_result.per_category_stats.items()): + cat_table.add_row( + cat, + str(stats.get("count", 0)), + f"{stats.get('flip_rate', 0.0):.1%}", + f"{stats.get('mean_similarity', 0.0):.4f}", + ) + + console.print(cat_table) + + # ── Flipped prompts detail ─────────────────────────────────────── + flipped = [c for c in compare_result.comparisons if c.metadata.get("flipped", False)] + if flipped: + flip_table = Table( + title=f"Flipped Prompts ({len(flipped)})", + show_header=True, + header_style="bold magenta", + ) + flip_table.add_column("prompt", style="dim", max_width=50, no_wrap=True) + flip_table.add_column("category", style="cyan") + flip_table.add_column("strategy", style="dim") + flip_table.add_column(f"{resolved_a.label}", max_width=30, no_wrap=True) + flip_table.add_column(f"{resolved_b.label}", max_width=30, no_wrap=True) + flip_table.add_column("similarity", justify="right") + + for c in flipped: + prompt_text = c.baseline.text if hasattr(c.baseline, "text") else c.baseline.prompt_id + # Truncate long prompt text for display. + if len(prompt_text) > 47: + prompt_text = prompt_text[:47] + "..." + + ans_a = c.metadata.get("answer_a", "?") + ans_b = c.metadata.get("answer_b", "?") + # Truncate long answers. + if len(str(ans_a)) > 27: + ans_a = str(ans_a)[:27] + "..." + if len(str(ans_b)) > 27: + ans_b = str(ans_b)[:27] + "..." + + flip_table.add_row( + prompt_text, + c.metadata.get("category", "?"), + c.metadata.get("extraction_strategy", "?"), + f"[green]{ans_a}[/green]", + f"[red]{ans_b}[/red]", + f"{c.text_similarity:.3f}", + ) + + console.print(flip_table) + + # ── Report generation ─────────────────────────────────────────── + if report: + from infer_check.reporting.html import generate_report + + report_path = output / f"report_{safe_a}_vs_{safe_b}.html" + generate_report(output, report_path) + console.print(f"[green]HTML report generated at {report_path}[/green]") + elif n > 0 and not flipped: + console.print("[bold green]No answer flips detected.[/bold green]") + + # --------------------------------------------------------------------------- # diff # --------------------------------------------------------------------------- @@ -259,16 +494,12 @@ def diff( from infer_check.suites.loader import load_suite backend_names = [b.strip() for b in backends.split(",") if b.strip()] - url_list: list[str | None] = ( - [u.strip() for u in base_urls.split(",")] if base_urls else [None] * len(backend_names) - ) + url_list: list[str | None] = [u.strip() for u in base_urls.split(",")] if base_urls else [None] * len(backend_names) # Pad url_list if shorter than backend_names while len(url_list) < len(backend_names): url_list.append(None) - console.print( - f"[bold cyan]diff[/bold cyan] model={model} backends={backend_names} quant={quant}" - ) + console.print(f"[bold cyan]diff[/bold cyan] model={model} backends={backend_names} quant={quant}") prompt_list = load_suite(_resolve_prompts(prompts)) @@ -336,7 +567,7 @@ def diff( @main.command() @click.option("--model", required=True, help="Model ID or HuggingFace path.") -@click.option("--backend", default="mlx-lm", show_default=True, help="Backend type.") +@click.option("--backend", default=None, help="Backend type (auto-detected if omitted).") @click.option( "--prompts", required=True, @@ -358,33 +589,31 @@ def diff( @click.option("--base-url", default=None, help="Base URL for HTTP backends.") def stress( model: str, - backend: str, + backend: str | None, prompts: str, output: Path, concurrency: str, base_url: str | None, ) -> None: """Stress-test a backend with varying concurrency levels.""" - from infer_check.backends.base import BackendConfig, get_backend + from infer_check.backends.base import get_backend_for_model from infer_check.runner import TestRunner from infer_check.suites.loader import load_suite concurrency_levels = [int(c.strip()) for c in concurrency.split(",") if c.strip()] + backend_instance = get_backend_for_model( + model_str=model, + backend_type=backend, + base_url=base_url, + ) + console.print( - f"[bold cyan]stress[/bold cyan] model={model} backend={backend} " - f"concurrency={concurrency_levels}" + f"[bold cyan]stress[/bold cyan] model={model} backend={backend_instance.name} concurrency={concurrency_levels}" ) prompt_list = load_suite(_resolve_prompts(prompts)) - config = BackendConfig( - backend_type=backend, # type: ignore[arg-type] - model_id=model, - base_url=base_url, - ) - backend_instance = get_backend(config) - runner = TestRunner() stress_results = asyncio.run( runner.stress( @@ -427,7 +656,7 @@ def stress( @main.command() @click.option("--model", required=True, help="Model ID or HuggingFace path.") -@click.option("--backend", default="mlx-lm", show_default=True, help="Backend type.") +@click.option("--backend", default=None, help="Backend type (auto-detected if omitted).") @click.option( "--prompts", required=True, @@ -444,27 +673,26 @@ def stress( @click.option("--base-url", default=None, help="Base URL for HTTP backends.") def determinism( model: str, - backend: str, + backend: str | None, prompts: str, output: Path, runs: int, base_url: str | None, ) -> None: """Test whether a backend produces identical outputs across repeated runs at temperature=0.""" - from infer_check.backends.base import BackendConfig, get_backend + from infer_check.backends.base import get_backend_for_model from infer_check.runner import TestRunner from infer_check.suites.loader import load_suite - console.print(f"[bold cyan]determinism[/bold cyan] model={model} backend={backend} runs={runs}") - - prompt_list = load_suite(_resolve_prompts(prompts)) - - config = BackendConfig( - backend_type=backend, # type: ignore[arg-type] - model_id=model, + backend_instance = get_backend_for_model( + model_str=model, + backend_type=backend, base_url=base_url, ) - backend_instance = get_backend(config) + + console.print(f"[bold cyan]determinism[/bold cyan] model={model} backend={backend_instance.name} runs={runs}") + + prompt_list = load_suite(_resolve_prompts(prompts)) runner = TestRunner() det_results = asyncio.run( @@ -507,8 +735,7 @@ def determinism( if det_results: overall = sum(r.determinism_score for r in det_results) / len(det_results) console.print( - f"\n[bold]Overall determinism score:[/bold] " - f"[{'green' if overall == 1.0 else 'yellow'}]{overall:.2%}[/]" + f"\n[bold]Overall determinism score:[/bold] [{'green' if overall == 1.0 else 'yellow'}]{overall:.2%}[/]" ) diff --git a/src/infer_check/prompt_suites/quant-sensitive.jsonl b/src/infer_check/prompt_suites/quant-sensitive.jsonl new file mode 100644 index 0000000..332a1a1 --- /dev/null +++ b/src/infer_check/prompt_suites/quant-sensitive.jsonl @@ -0,0 +1,20 @@ +{"text": "Calculate 123456 * 987654. Show the step-by-step multiplication.", "category": "multi_digit_arithmetic", "max_tokens": 1024} +{"text": "What is the 10th root of 2? Provide the result to 10 decimal places.", "category": "precision_numerics", "max_tokens": 256} +{"text": "Determine if 9999999999999997 is a prime number. Explain your reasoning in detail.", "category": "large_number_reasoning", "max_tokens": 1024} +{"text": "A farmer has 17 sheep. All but 9 die. How many are left? Explain the logic step-by-step.", "category": "logical_puzzle", "max_tokens": 256} +{"text": "If a train leaves station A at 60 mph and another leaves station B at 90 mph, and they are 300 miles apart, when and where do they meet? Provide a detailed derivation.", "category": "long_chain_of_thought", "max_tokens": 512} +{"text": "Write a valid YAML configuration for a Kubernetes Deployment with a custom health check and resource limits. Ensure the indentation is perfect.", "category": "precise_syntax", "max_tokens": 512} +{"text": "Generate a complex nested JSON object with at least 5 levels of nesting, containing arrays of objects with mixed types. Ensure it's valid JSON.", "category": "precise_syntax", "max_tokens": 512} +{"text": "Translate this SQL query into a Python list comprehension: SELECT name FROM users WHERE age > 21 AND city = 'New York' ORDER BY name LIMIT 10", "category": "code_translation", "max_tokens": 256} +{"text": "Explain the difference between a shallow copy and a deep copy in Python with code examples showing the internal memory addresses using id().", "category": "long_chain_of_thought", "max_tokens": 1024} +{"text": "Solve for x: log2(x) + log2(x-2) = 3. Show every step of the algebraic manipulation.", "category": "algebraic_reasoning", "max_tokens": 512} +{"text": "What is the 50th Fibonacci number? Calculate it precisely without using scientific notation.", "category": "large_number_reasoning", "max_tokens": 512} +{"text": "Describe the steps of the SHA-256 hashing algorithm at a high level, but include the specific constants used in the initial hash values.", "category": "precision_numerics", "max_tokens": 1024} +{"text": "Write a Rust function that uses unsafe code to swap two integers using pointers. Explain why it is unsafe.", "category": "precise_syntax", "max_tokens": 512} +{"text": "Create a strictly formatted CSV list of the first 20 prime numbers, separated by semicolons, with each value enclosed in double quotes.", "category": "precise_syntax", "max_tokens": 256} +{"text": "If a clock shows 3:15, what is the exact angle between the hour and minute hands? Provide the derivation.", "category": "logical_puzzle", "max_tokens": 256} +{"text": "Compare the time complexity of QuickSort and MergeSort in the worst case, providing mathematical proofs for the Big O notation of each.", "category": "long_chain_of_thought", "max_tokens": 1024} +{"text": "Write a regular expression that matches valid email addresses according to RFC 5322, and explain each part of the regex.", "category": "precise_syntax", "max_tokens": 512} +{"text": "What is the exact value of e (Euler's number) to 30 decimal places?", "category": "precision_numerics", "max_tokens": 256} +{"text": "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain your reasoning.", "category": "logical_puzzle", "max_tokens": 256} +{"text": "Analyze the potential for race conditions in a multithreaded Python program that increments a global counter without locks. Provide a code example and explain how to fix it using threading.Lock.", "category": "long_chain_of_thought", "max_tokens": 1024} diff --git a/src/infer_check/reporting/html.py b/src/infer_check/reporting/html.py index 1af7c76..c662157 100644 --- a/src/infer_check/reporting/html.py +++ b/src/infer_check/reporting/html.py @@ -16,6 +16,7 @@ from jinja2 import Environment, Undefined from infer_check.types import ( + CompareResult, ComparisonResult, DeterminismResult, StressResult, @@ -318,6 +319,7 @@