Batch Invariant Operations for Deterministic LLM Inference

This project builds on "Defeating Nondeterminism in LLM Inference" by adding reproducible tests, documentation, and tooling that make the batch-invariant operators easier to evaluate and reuse.

Relationship to the Original Repository

This fork builds directly on the Thinking Machines Lab batch_invariant_ops repository, which introduced the core concept and provided a clear demonstration of batch-size–induced nondeterminism. Their work shows how standard operators can produce different results under different batching.

The goal of this fork is not to replace or diverge from that work, but to add rigor and breadth:

  • Library structure: Exposing the operators through a package layout (src/) with importable functions (matmul_persistent, set_batch_invariant_mode), so they can be reused and tested in broader contexts.

  • Unit tests: Introducing a reproducible test harness (pytest), moving beyond a single demo script toward systematic verification.

  • Documentation: Expanding the README and adding docs to guide reproducibility and integration.

  • Research agenda: Issue #2 proposes a set of toy-model benchmarks (sequence reversal, running sum, parity, etc.) to isolate batch-size–induced nondeterminism. These are designed to complement the original transformer-based example with falsifiable, operator-level tests.

Project Goals

  1. Validate batch-invariance claims across diverse hardware configurations
  2. Benchmark determinism vs performance trade-offs for various model sizes
  3. Develop production-ready batch-invariant kernels for vLLM
  4. Create reproducible testing framework for LLM determinism research
  5. Explore optimization strategies to reduce the ~60% performance overhead

Overview

This library primarily leverages torch.Library to substitute existing PyTorch kernels with "batch-invariant" ones. This allows many existing PyTorch models to use the batch-invariant ops with low overhead and non-intrusive code changes.
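
As a rough illustration of the mechanism (the stand-in kernel below is not the library's real implementation, which uses Triton), torch.Library lets a Python function take over an existing ATen operator for a given dispatch key:

import torch

# Route aten::mm on CUDA to a custom implementation. The real library installs
# Triton kernels with a fixed reduction order; this stand-in reuses bmm so the
# example stays self-contained and avoids re-entering the overridden mm.
_lib = torch.library.Library("aten", "IMPL")

def _batch_invariant_mm(a, b):
    return torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0)

# This override stays in effect process-wide; set_batch_invariant_mode()
# scopes the real registrations inside a context manager instead.
_lib.impl("mm", _batch_invariant_mm, "CUDA")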

Installation

# Clone the repository
git clone https://github.com/tnn1t1s/batch_invariant_ops.git
cd batch_invariant_ops

# Set up virtual environment and install
python3 -m venv .venv
source .venv/bin/activate
pip install -e .

# Optional: Clone vLLM for integration testing
git clone https://github.com/vllm-project/vllm.git

Quick Start

import torch
from batch_invariant_ops import set_batch_invariant_mode

# Enable batch-invariant mode
with set_batch_invariant_mode():
    # Your inference code here
    model = YourModel()
    output = model(input_tensor)

Testing Batch-Invariance

The following example shows how batch size can affect results in standard PyTorch:

import torch
from batch_invariant_ops import set_batch_invariant_mode

torch.set_default_device("cuda")

def test_batch_invariance():
    B, D = 2048, 4096
    a = torch.linspace(-100, 100, B * D).reshape(B, D)
    b = torch.linspace(-100, 100, D * D).reshape(D, D)
    
    # Method 1: Matrix-vector multiplication (batch size 1)
    out1 = torch.mm(a[:1], b)
    
    # Method 2: Matrix-matrix multiplication, then slice (full batch)
    out2 = torch.mm(a, b)[:1]
    
    # Check if results are identical
    diff = (out1 - out2).abs().max()
    print(f"Difference: {diff.item()}")
    return diff.item() == 0

# Test with standard PyTorch (likely to show differences)
print("Standard PyTorch:")
with set_batch_invariant_mode(False):
    deterministic = test_batch_invariance()
    print(f"Deterministic: {deterministic}")

# Test with batch-invariant operations
print("\nBatch-Invariant Mode:")
with set_batch_invariant_mode(True):
    deterministic = test_batch_invariance()
    print(f"Deterministic: {deterministic}")

Testing & Validation

Run Tests

# Run unit tests (fast, ~1.25s)
pytest tests/unit/ -v

# Run integration tests
pytest tests/integration/ -v

# Run all tests with coverage
pytest tests/ --cov=batch_invariant_ops

# Run specific test suites
pytest tests/integration/test_progressive_complexity.py -v

Test Results

  • 29/29 unit tests passing - All kernels validated
  • Perfect batch invariance - 0.00e+00 difference across batch sizes
  • No error accumulation - Maintained through 4+ transformer layers
  • Acceptable performance - ~2x overhead for small batches, improves with scale

Integration Testing

Our progressive complexity tests validate batch invariance across:

  • Single MatMul operations
  • Sequential MatMul chains
  • Multi-head attention mechanisms
  • Full transformer blocks
  • Deep networks (4+ layers)
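
The shape of one such check, written as a pytest-style test (the layer sizes and exact-equality assertion are illustrative, not the suite's actual code):

import torch
from batch_invariant_ops import set_batch_invariant_mode

def test_chained_matmul_batch_invariance():
    # Two chained matmuls, comparing a batch-of-1 pass against the first row
    # of a full-batch pass. With batch-invariant mode enabled, the two paths
    # should match bit-for-bit.
    torch.manual_seed(0)
    x = torch.randn(64, 256, device="cuda")
    w1 = torch.randn(256, 256, device="cuda")
    w2 = torch.randn(256, 256, device="cuda")

    with set_batch_invariant_mode():
        single = torch.mm(torch.mm(x[:1], w1), w2)
        full = torch.mm(torch.mm(x, w1), w2)[:1]

    assert torch.equal(single, full)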

Deterministic Inference in vLLM

deterministic_vllm_inference.py is a proof of concept validating that vLLM can be made deterministic with a minor upstream PR that routes inference through this library. Without that PR, 1,000 completions of length 100 yield 18 unique samples; with it, there is only one.
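
The gist of that measurement, sketched against vLLM's offline API (the model name, prompt, and sampling settings below are placeholders, not the script's actual configuration):

from vllm import LLM, SamplingParams

# Issue many identical requests and count how many distinct completions come
# back. Without batch-invariant kernels the engine's dynamic batching can
# yield several unique samples; with them, exactly one is expected.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")   # placeholder model
params = SamplingParams(temperature=0.0, max_tokens=100)  # greedy, for illustration

outputs = llm.generate(["Tell me about Richard Feynman."] * 1000, params)
unique_completions = {o.outputs[0].text for o in outputs}
print(f"Unique completions: {len(unique_completions)}")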

Supported Operations

Matrix Operations

  • torch.mm() - Matrix multiplication
  • torch.addmm() - Matrix multiplication with bias addition

Activation Functions

  • torch.log_softmax() - Log-softmax activation

Reduction Operations

  • torch.mean() - Mean computation along specified dimensions
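
All of the ops above are intercepted inside the context manager; a minimal exercise (shapes are arbitrary):

import torch
from batch_invariant_ops import set_batch_invariant_mode

torch.set_default_device("cuda")

with set_batch_invariant_mode():
    x = torch.randn(8, 16)
    w = torch.randn(16, 16)
    bias = torch.randn(16)

    y = torch.mm(x, w)                    # batch-invariant matmul
    y = torch.addmm(bias, x, w)           # matmul with bias addition
    logp = torch.log_softmax(y, dim=-1)   # batch-invariant log-softmax
    avg = logp.mean(dim=-1)               # batch-invariant mean reduction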

Project Structure

batch_invariant_ops/
├── src/batch_invariant_ops/      # Core implementation
│   ├── __init__.py               # Public API
│   └── kernels.py                # Triton kernels
├── tests/
│   ├── fixtures/                 # Shared test models
│   ├── unit/                     # Fast kernel tests (~1.25s)
│   └── integration/              # End-to-end tests
├── docs/                         # Documentation
│   ├── TESTING.md               # Testing strategy & results
│   └── defeating-nondeterminism.md  # Original paper
└── CLAUDE.md                     # Development guide

Toy Model Experiments (Issue #2)

To demonstrate that batch nondeterminism is fundamental to transformer architectures (not just a quirk of large models), we've implemented 5 simple toy tasks in experiments/toy_models/:

The Five Tasks

  1. Sequence Reversal: Reverse input sequence → Tests attention patterns
  2. Running Sum: Cumulative sum → Shows error accumulation
  3. Modular Arithmetic: Input mod 5 → Non-linear transformations amplify differences
  4. XOR Parity: Running XOR → Binary decisions make errors obvious
  5. Copy-Shift: Rotate sequence left → Pattern matching sensitive to attention
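
For a sense of what these tasks look like as data, here is one sequence-reversal example (vocabulary size and sequence length are arbitrary, not the experiment's settings):

import torch

# One training example for the sequence-reversal task: the target is simply
# the input sequence read backwards.
vocab_size, seq_len = 16, 8
tokens = torch.randint(0, vocab_size, (seq_len,))
target = tokens.flip(0)
print(tokens.tolist(), "->", target.tolist())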

Quick Demo

# Train a model on sequence reversal
cd experiments/toy_models
python train.py --task reversal --epochs 100

# Test batch variance (same input, different batch sizes)
python batch_variance_test.py --task reversal

# Test with batch-invariant mode
python batch_variance_test.py --task reversal --test-invariant

Key Results

  • Even tiny 2-layer transformers exhibit batch nondeterminism
  • Different batch sizes produce different outputs for the SAME input
  • Batch-invariant operations completely eliminate the variance
  • See experiments/toy_models/README.md for detailed results
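
The core measurement behind these results can be sketched as follows; the two-matmul "model" stands in for the trained toy transformers and is not the experiment code:

import torch
from batch_invariant_ops import set_batch_invariant_mode

torch.set_default_device("cuda")
torch.manual_seed(0)

# Stand-in for a trained toy model; plain matmuls keep the sketch
# self-contained while still exercising the swapped kernels.
w1 = torch.randn(32, 64)
w2 = torch.randn(64, 32)

def toy_model(h):
    return torch.relu(h @ w1) @ w2

x = torch.randn(1, 32)           # one input (flattened)
others = torch.randn(63, 32)     # unrelated inputs that change the batch size

with set_batch_invariant_mode():
    alone = toy_model(x)
    batched = toy_model(torch.cat([x, others]))[:1]

print((alone - batched).abs().max().item())  # expected 0.0 in this mode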

Benchmark Metrics

Determinism Metrics

  • Exact Match Rate: Percentage of identical outputs
  • Token Divergence Point: First token position where outputs differ
  • Numerical Drift: Maximum floating-point difference per layer
  • Entropy Score: Randomness measure across multiple runs
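
For concreteness, the first two metrics can be operationalized along these lines (helper names are illustrative, not part of the package):

from collections import Counter

def exact_match_rate(runs):
    # runs: list of token-ID sequences produced from the same prompt.
    # Fraction of runs identical to the most common output.
    _, count = Counter(tuple(r) for r in runs).most_common(1)[0]
    return count / len(runs)

def token_divergence_point(a, b):
    # First position where two runs disagree; None if they match entirely.
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return None if len(a) == len(b) else min(len(a), len(b))

runs = [[5, 9, 2, 7], [5, 9, 2, 7], [5, 9, 4, 1]]
print(exact_match_rate(runs))                    # 0.666...
print(token_divergence_point(runs[0], runs[2]))  # 2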

Performance Metrics

  • Throughput: Tokens/second
  • Latency: Time to first token (TTFT) and inter-token latency
  • Memory Usage: Peak VRAM consumption
  • Overhead: Percentage increase vs non-deterministic baseline
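
The overhead number can be probed with a simple timing loop like this (matrix sizes and iteration count are arbitrary):

import time
import torch
from batch_invariant_ops import set_batch_invariant_mode

def time_matmul(enabled, iters=50):
    # Rough throughput probe: time repeated matmuls with and without the
    # batch-invariant kernels.
    a = torch.randn(2048, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    with set_batch_invariant_mode(enabled):
        torch.mm(a, b)                      # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            torch.mm(a, b)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

baseline = time_matmul(False)
invariant = time_matmul(True)
print(f"overhead: {100 * (invariant / baseline - 1):.1f}%")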

Contributing

We welcome contributions in the following areas:

  • Kernel optimizations to reduce overhead
  • Testing on additional GPU architectures
  • Support for quantized models (GPTQ, AWQ, GGUF)
  • Multi-GPU inference determinism
  • Integration with other inference engines (TGI, TensorRT-LLM)

Please see CONTRIBUTING.md for detailed guidelines.

Recent Progress

September 2025

  • ✅ Fixed Triton kernel compatibility issues
  • ✅ Achieved 29/29 unit tests passing with proper numerical tolerances
  • ✅ Created progressive complexity test suite for transformer models
  • ✅ Validated perfect batch invariance (0.00 difference) across all operations
  • ✅ Consolidated repository structure for better maintainability
  • ✅ Documented comprehensive testing strategy and results
  • ✅ Implemented 5 toy models demonstrating batch nondeterminism (Issue #2)
  • ✅ Added relationship to original Thinking Machines Lab repository

Citation

If you use this work in your research, please cite:

@software{batch_invariant_ops2025,
  title={Batch Invariant Operations for Deterministic LLM Inference},
  author={David J. Palaitis},
  year={2025},
  url={https://github.com/tnn1t1s/batch_invariant_ops}
}

Acknowledgments

  • Original research by Thinking Machines Lab
  • vLLM team for the inference engine
  • Contributors and testers from the community
