Document QKAN solver selection (benchmark of exact vs flash vs cutn solver + pz vs real ansatz) #8

@Jim137

Description


Motivation

We currently expose multiple solver backends for QKAN (exact, flash, cutn) and two ansatz variants (pz, real). In practice, solver choice can change training step time and peak memory by large factors, and the “best” option depends heavily on model scale and workload.

This issue proposes a clear Solver Guide section to the README/docs using the benchmarks below as evidence, so users can choose the right backend without trial-and-error.

Solver guide

| Case | Device | Recommended solver | Why | Notes |
|---|---|---|---|---|
| Small models, CPU runs, debugging, or when you want a trusted baseline | CPU (or GPU) | exact (default) | Simple, "reference" behavior | First run may include one-time init overhead; do a warmup step before timing. |
| Most training workloads (medium to large models) and inference | GPU | flash | Best overall speed/memory tradeoff in these benchmarks | Good first choice for practical GPU training. |
| Extremely large or memory-bound runs (near OOM, very large layers/batches) | GPU | cutn | Best scaling and peak-memory reduction in the extreme benchmark | Can be slower than flash on mid-size problems; use when size/memory dominates. |
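The decision table above can be condensed into a few lines of Python. The helper below is purely illustrative: the function name and its parameters are ours, not part of the QKAN API; only the backend names (exact, flash, cutn) come from the library.

```python
def recommend_solver(on_gpu: bool, memory_bound: bool = False,
                     debugging: bool = False) -> str:
    """Illustrative helper encoding the solver-guide table.

    Returns one of the QKAN backend names: "exact", "flash", "cutn".
    The function signature is ours, not part of the QKAN API.
    """
    if debugging or not on_gpu:
        # Small models, CPU runs, or when a trusted reference baseline is wanted.
        return "exact"
    if memory_bound:
        # Near-OOM, very large layers/batches: best peak-memory scaling.
        return "cutn"
    # Default GPU choice: best overall speed/memory tradeoff in the benchmarks.
    return "flash"
```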

Ansatz choice

  • Default: pz — most reliable quality across tasks.
  • real can be faster and smaller, but it may hurt accuracy or convergence on some workloads; use it only after validating it on your task.
  • rpz (reduced-pz encoding) is a compromise: it uses fewer gates to achieve the same mathematical structure as pz encoding. However, it raises the data-encoding gate dimension to five, which can reduce computational efficiency compared with plain pz without preact_trainable; we suggest enabling preact_trainable to further improve rpz accuracy. (rpz is not benchmarked here; a more comprehensive benchmark across all ansatzes will be added to the examples section later.)

Benchmark environment

  • CPU: AMD Ryzen 9 9950X3D
  • GPU: NVIDIA GeForce RTX 5090
  • PyTorch: 2.9.1+cu128
  • CUDA: 12.8

The results below were updated by PR #11, which implemented an improved FlashQKAN.
We also added the JHCG Net baseline, which replaces the QKAN in HQKAN with efficient-KAN (with B-Splines).

Key findings

  1. Small 1D function fitting (Benchmark 1)
  • flash_pz is ~11.5× faster on forward and ~5.4× faster per training step vs exact_pz, with identical test loss (0.0947).
  • real ansatz is not quality-preserving here: test loss is much worse (~0.138 vs 0.0947) despite speedups.
  • cutn_pz is only ~1.03× faster per step than exact_pz at this tiny scale.
  2. CIFAR-100 HQKAN-44 (Benchmark 2)
  • pz variants maintain accuracy (Top-1 ~34.6–35.1%), while real variants collapse (Top-1 ~9.8–10.9%).
  • Among QKAN pz solvers, cutn_pz is fastest here (104.8s vs 118.9s for exact_pz), with flash_pz close (107.6s).
  • Peak memory improves substantially vs exact_pz (862.7 MiB): flash_pz (484.9 MiB) and cutn_pz (457.9 MiB).
  3. GPT-2 (batch=1) TinyShakespeare (Benchmark 3)
  • Fastest QKAN: flash_pz (14.681 ms/step, 29.4s total; 1.69× faster than exact_pz), with lower peak memory (~697 MiB).
  • Best QKAN loss is achieved by exact_real (2.351), but real is only recommended for transformers after validation.
  • QKAN variants use far less peak memory than MLP baselines (e.g., flash_pz 697 MiB vs mlp 1698 MiB).
  4. GPT-2 (batch=10) WebText (Benchmark 4)
  • Fastest QKAN: flash_pz / flash_real at 120.0 s total (1.55× faster than exact_pz), with the lowest QKAN peak memory (~11.64 GiB).
  • exact_real / cutn_real are also fast (126–127 s) with slightly higher peak (11.81 GiB).
  • MLP achieves the best loss (6.230) but uses 1.84× more parameters (123.69M vs 67.23M) and much higher peak memory (15.16 GiB). triton_mlp improves speed (150.6 s) and reduces peak vs MLP, but is still slower than QKAN flash.
  5. Extreme synthetic QKAN (Benchmark 5)
  • flash_pz is dominant for speed with good convergence: 20.31× faster training steps than exact_pz with identical train loss (0.8585).
  • cutn_pz trades speed for memory: only 2.15× faster than exact_pz, but cuts peak memory to 1153 MiB (~55% lower than exact_pz).
  • real variants are extremely fast but do not converge (train loss ~31) — not usable in this setting.

Benchmark 1: README Function Fitting

Model: QKAN([1, 1], reps=3) with trainable pre/post activations
Data: 1000 train / 1000 test, 1D, function sin(20x)/(20x)
Training: Adam lr=1e-3, 100 steps
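The solver guide recommends a warmup step before timing, since the first call can include one-time init overhead. A minimal, dependency-free sketch of how per-step times like those below can be measured; the actual benchmark harness may differ, and on GPU one would also synchronize the device (e.g. torch.cuda.synchronize()) around the timed region:

```python
import time

def time_per_step(step_fn, n_warmup: int = 3, n_steps: int = 100) -> float:
    """Average wall-clock milliseconds per call of step_fn.

    Warmup calls absorb one-time init overhead (e.g. the exact solver's
    first-run cost noted in the solver guide) before timing begins.
    Kept dependency-free; a GPU harness would also synchronize the device.
    """
    for _ in range(n_warmup):
        step_fn()
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    return (time.perf_counter() - start) * 1000.0 / n_steps
```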

| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 100 Steps | Peak Mem | Avg Mem | Test Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 18 | 9.0 ms | 0.829 ms | 1.00x | 3.328 ms | 1.00x | 332.8 ms | 16.6 MiB | 16.3 MiB | 0.0947 |
| exact_real | real | 13 | 0.5 ms | 0.572 ms | 1.45x | 2.528 ms | 1.32x | 252.8 ms | 16.4 MiB | 16.3 MiB | 0.1381 |
| flash_pz | pz | 18 | 0.4 ms | 0.072 ms | 11.46x | 0.619 ms | 5.38x | 61.9 ms | 16.5 MiB | 16.3 MiB | 0.0947 |
| flash_real | real | 13 | 0.5 ms | 0.080 ms | 10.39x | 0.668 ms | 4.98x | 66.8 ms | 16.4 MiB | 16.3 MiB | 0.1386 |
| cutn_pz | pz | 18 | 0.5 ms | 0.703 ms | 1.18x | 3.245 ms | 1.03x | 324.5 ms | 16.6 MiB | 16.3 MiB | 0.0947 |
| cutn_real | real | 13 | 0.5 ms | 0.484 ms | 1.71x | 2.264 ms | 1.47x | 226.4 ms | 16.4 MiB | 16.3 MiB | 0.1381 |
| kan | bspline | 10 | 1.0 ms | 0.195 ms | 4.25x | 0.595 ms | 5.59x | 59.5 ms | 16.4 MiB | 16.3 MiB | 0.0697 |

Benchmark 2: HQKAN CIFAR-100

HQKAN-44: CNet -> Linear(256, 32) -> QKAN([32, 28]) -> Linear(28, 100) (notebook)
JHCG: CNet -> Linear(256, 32) -> KAN([32, 28]) -> Linear(28, 100)
Data: CIFAR-100, batch size 1000
Training: Adam lr=1e-3, 50 epochs

| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 50ep Time | Peak Mem | Avg Mem | Test Loss | Top-1 | Top-5 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 83572 | 1.3 ms | 4.134 ms | 1.00x | 47.564 ms | 1.00x | 118.9 s | 862.7 MiB | 42.6 MiB | 2.573 | 35.1% | 66.5% |
| exact_real | real | 79092 | 0.5 ms | 1.437 ms | 2.88x | 44.662 ms | 1.06x | 111.7 s | 373.0 MiB | 42.4 MiB | 3.862 | 10.9% | 33.0% |
| flash_pz | pz | 83572 | 0.5 ms | 0.920 ms | 4.49x | 43.024 ms | 1.11x | 107.6 s | 484.9 MiB | 42.6 MiB | 2.562 | 34.8% | 66.4% |
| flash_real | real | 79092 | 0.4 ms | 0.909 ms | 4.55x | 41.933 ms | 1.13x | 104.8 s | 385.2 MiB | 42.4 MiB | 4.026 | 9.8% | 30.4% |
| cutn_pz | pz | 83572 | 0.4 ms | 2.832 ms | 1.46x | 41.909 ms | 1.13x | 104.8 s | 457.9 MiB | 42.6 MiB | 2.574 | 34.6% | 66.0% |
| cutn_real | real | 79092 | 0.4 ms | 1.403 ms | 2.95x | 48.621 ms | 0.98x | 121.6 s | 373.0 MiB | 42.4 MiB | 3.868 | 10.8% | 32.8% |
| JHCG | bspline | 76404 | 17.4 ms | 0.968 ms | 4.27x | 39.762 ms | 1.20x | 99.4 s | 373.1 MiB | 42.5 MiB | 2.645 | 33.7% | 64.4% |

Benchmark 3: HQKANsformer vs MLP GPT-2

HQKANsformer: GPT-2 (12L, 12H, 768E) with Linear(768,10) -> QKAN([10,10], reps=1) -> Linear(10,768) replacing MLP
MLP GPT-2: Standard GPT-2 with Linear(768,3072) -> GELU -> Linear(3072,768) MLP
Triton MLP GPT-2: Same as MLP but with fused Triton bias+GELU kernel (avoids intermediate materialization)
Data: TinyShakespeare (char-level), batch size 1, block size 1024
Training: AdamW lr=0.0003, betas=(0.9,0.95), weight_decay=0.1, grad_clip=1.0, 2000 iters

| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 2000 Iters | Peak Mem | Avg Mem | Final Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 28.64M | 149.8 ms | 9.101 ms | 1.00x | 24.799 ms | 1.00x | 49.6 s | 810.8 MiB | 492.9 MiB | 2.398 |
| exact_real | real | 28.64M | 215.8 ms | 7.175 ms | 1.27x | 20.721 ms | 1.20x | 41.4 s | 717.1 MiB | 497.0 MiB | 2.351 |
| flash_pz | pz | 28.64M | 216.6 ms | 4.520 ms | 2.01x | 14.681 ms | 1.69x | 29.4 s | 697.2 MiB | 491.2 MiB | 2.399 |
| flash_real | real | 28.64M | 213.4 ms | 4.622 ms | 1.97x | 15.081 ms | 1.64x | 30.2 s | 698.4 MiB | 494.4 MiB | 2.377 |
| cutn_pz | pz | 28.64M | 216.1 ms | 10.049 ms | 0.91x | 27.717 ms | 0.89x | 55.4 s | 784.6 MiB | 493.7 MiB | 2.402 |
| cutn_real | real | 28.64M | 152.1 ms | 7.104 ms | 1.28x | 20.233 ms | 1.23x | 40.5 s | 715.4 MiB | 495.8 MiB | 2.375 |
| JHCG | bspline | 28.64M | 157.9 ms | 5.267 ms | 1.73x | 16.802 ms | 1.48x | 33.6 s | 729.6 MiB | 493.4 MiB | 2.494 |
| triton_mlp | mlp | 85.11M | 411.1 ms | 6.737 ms | 1.35x | 18.086 ms | 1.37x | 36.2 s | 1488.3 MiB | 1040.4 MiB | 2.488 |
| plain_mlp | mlp | 85.11M | 496.5 ms | 7.042 ms | 1.29x | 23.850 ms | 1.04x | 47.7 s | 1697.7 MiB | 1366.3 MiB | 2.528 |
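The parameter gap between the MLP baselines (85.11M) and the QKAN variants (28.64M) follows almost entirely from the bottleneck shapes listed above. A quick arithmetic check, ignoring the small per-block QKAN([10, 10], reps=1) parameter count (whose exact value is not listed here):

```python
# Sanity-check the parameter gap between MLP GPT-2 (85.11M) and the
# HQKANsformer (28.64M) from the layer shapes given above.

def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weights + bias

n_layers = 12
# Standard GPT-2 MLP block: Linear(768, 3072) -> GELU -> Linear(3072, 768)
mlp_block = linear_params(768, 3072) + linear_params(3072, 768)
# HQKANsformer wrappers: Linear(768, 10) -> QKAN -> Linear(10, 768)
qkan_wrappers = linear_params(768, 10) + linear_params(10, 768)

gap = n_layers * (mlp_block - qkan_wrappers)
print(f"{gap / 1e6:.2f}M")  # prints "56.48M", close to 85.11M - 28.64M = 56.47M
```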

Benchmark 4: HQKANsformer vs MLP GPT-2 on WebText (batch=10)

HQKANsformer: GPT-2 (12L, 12H, 768E) with Linear(768,10) -> QKAN([10,10], reps=1) -> Linear(10,768) replacing MLP
MLP GPT-2: Standard GPT-2 with Linear(768,3072) -> GELU -> Linear(3072,768) MLP
Triton MLP GPT-2: Same as MLP but with fused Triton bias+GELU kernel (avoids intermediate materialization)
Data: WebText (GPT-2 tokenizer, vocab_size=50304), batch size 10, block size 1024
Training: AdamW lr=0.0003, betas=(0.9,0.95), weight_decay=0.1, grad_clip=1.0, 1000 iters

| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 1000 Iters | Peak Mem | Avg Mem | Final Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 67.23M | 515.6 ms | 66.981 ms | 1.00x | 185.643 ms | 1.00x | 185.6 s | 12830.0 MiB | 3058.1 MiB | 6.491 |
| exact_real | real | 67.22M | 520.8 ms | 42.139 ms | 1.59x | 126.959 ms | 1.46x | 127.0 s | 11811.3 MiB | 3058.9 MiB | 6.407 |
| flash_pz | pz | 67.23M | 529.3 ms | 39.557 ms | 1.69x | 120.033 ms | 1.55x | 120.0 s | 11646.3 MiB | 3047.8 MiB | 6.428 |
| flash_real | real | 67.22M | 567.0 ms | 39.653 ms | 1.69x | 119.964 ms | 1.55x | 120.0 s | 11644.7 MiB | 3046.0 MiB | 6.417 |
| cutn_pz | pz | 67.23M | 536.0 ms | 57.676 ms | 1.16x | 175.005 ms | 1.06x | 175.0 s | 12443.1 MiB | 3045.9 MiB | 6.394 |
| cutn_real | real | 67.22M | 513.9 ms | 41.874 ms | 1.60x | 126.224 ms | 1.47x | 126.2 s | 11807.0 MiB | 3050.1 MiB | 6.409 |
| JHCG | bspline | 67.23M | 511.8 ms | 40.067 ms | 1.67x | 122.693 ms | 1.51x | 122.7 s | 11994.0 MiB | 3050.3 MiB | 6.603 |
| triton_mlp | mlp | 123.69M | 777.7 ms | 64.211 ms | 1.04x | 150.647 ms | 1.23x | 150.6 s | 12781.0 MiB | 3588.1 MiB | 6.332 |
| plain_mlp | mlp | 123.69M | 828.0 ms | 64.256 ms | 1.04x | 183.333 ms | 1.01x | 183.3 s | 15159.6 MiB | 3914.4 MiB | 6.230 |

Benchmark 5: Extreme Synthetic QKAN

Model: QKAN([100, 100], reps=3)
Data: Random 1000x100 input/output
Training: Adam lr=1e-3, 50 steps

| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 50 Steps | Peak Mem | Avg Mem | Train Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 180000 | 0.5 ms | 36.709 ms | 1.00x | 78.111 ms | 1.00x | 3905.6 ms | 2569.8 MiB | 41.0 MiB | 0.8585 |
| exact_real | real | 130000 | 0.5 ms | 3.476 ms | 10.56x | 8.568 ms | 9.12x | 428.4 ms | 327.5 MiB | 39.8 MiB | 31.1815 |
| flash_pz | pz | 180000 | 0.5 ms | 0.635 ms | 57.77x | 3.845 ms | 20.31x | 192.3 ms | 1955.5 MiB | 41.0 MiB | 0.8585 |
| flash_real | real | 130000 | 0.5 ms | 0.419 ms | 87.57x | 1.563 ms | 49.97x | 78.2 ms | 842.4 MiB | 39.8 MiB | 31.1592 |
| cutn_pz | pz | 180000 | 0.4 ms | 16.139 ms | 2.27x | 36.334 ms | 2.15x | 1816.7 ms | 1153.0 MiB | 41.0 MiB | 0.8585 |
| cutn_real | real | 130000 | 0.4 ms | 3.785 ms | 9.70x | 8.213 ms | 9.51x | 410.6 ms | 327.5 MiB | 39.8 MiB | 31.1815 |
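The headline ratios in the key findings follow directly from the raw columns above; for example, the flash_pz step speedup and the cutn_pz peak-memory reduction can be recomputed:

```python
# Recompute two headline ratios from the Benchmark 5 raw numbers above.
exact_step_ms, flash_step_ms = 78.111, 3.845
exact_peak_mib, cutn_peak_mib = 2569.8, 1153.0

speedup = exact_step_ms / flash_step_ms
mem_reduction = 1.0 - cutn_peak_mib / exact_peak_mib

print(f"{speedup:.2f}x")       # 20.31x: the flash_pz "Step vs exact_pz" entry
print(f"{mem_reduction:.0%}")  # 55%: cutn_pz peak memory cut vs exact_pz
```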
