Document QKAN solver selection (benchmark of exact vs flash vs cutn solver + pz vs real ansatz) #8
Description
Motivation
We currently expose multiple solver backends for QKAN (exact, flash, cutn) and two ansatz variants (pz, real). In practice, solver choice can change training step time and peak memory by large factors, and the “best” option depends heavily on model scale and workload.
This issue proposes adding a clear "Solver Guide" section to the README/docs, using the benchmarks below as evidence, so users can choose the right backend without trial and error.
Solver Guide
| Case | Device | Recommended solver | Why | Notes |
|---|---|---|---|---|
| Small models, CPU runs, debugging, or you want a trusted baseline | CPU (or GPU) | exact (default) | Simple, "reference" behavior | First run may include one-time init overhead; do a warmup step before timing. |
| Most training workloads (medium to large models) and inference | GPU | flash | Best overall speed/memory tradeoff in these benchmarks | Good first choice for practical GPU training. |
| Extremely large or memory-bound runs (near OOM, very large layers/batches) | GPU | cutn | Best scaling and peak-memory reduction in the extreme benchmark | Can be slower than flash on mid-size problems; use when size/memory dominates. |
Ansatz choice
- Default: `pz`, the most reliable quality across tasks.
- `real` can be faster and smaller, but may hurt accuracy/convergence on some workloads; only use it if you validate it on your task.
- `rpz` (reduced pz encoding) can be viewed as a compromise: it uses fewer gates to achieve the same mathematical structure as `pz` encoding. However, it increases the data-encoding gate dimension to five, which can reduce computational efficiency compared to plain `pz` without `preact_trainable`. We suggest turning on `preact_trainable` to further improve the accuracy of the `rpz` ansatz. (`rpz` was not benchmarked here; a more comprehensive benchmark across all ansatzes will be added to the examples section later.)
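The recommendations above could be captured in a small config sketch. The ansatz names and the `preact_trainable` flag come from the discussion above; the dataclass itself is hypothetical, not the library's constructor.

```python
from dataclasses import dataclass

# Illustrative config sketch only; not the QKAN API.
@dataclass
class AnsatzConfig:
    ansatz: str = "pz"             # "pz" (default), "real", or "rpz"
    preact_trainable: bool = False

    def validated(self) -> "AnsatzConfig":
        if self.ansatz not in ("pz", "real", "rpz"):
            raise ValueError(f"unknown ansatz: {self.ansatz}")
        # Per the note above, rpz benefits from trainable pre-activations.
        if self.ansatz == "rpz" and not self.preact_trainable:
            print("hint: consider preact_trainable=True with rpz")
        return self
```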
Benchmark environment
- CPU: AMD Ryzen 9 9950X3D
- GPU: NVIDIA GeForce RTX 5090
- PyTorch: 2.9.1+cu128
- CUDA: 12.8
The results below were updated by PR #11, which implemented an improved FlashQKAN.
We also added the JHCG Net baseline, which replaces the QKAN in HQKAN with efficient-KAN (with B-splines).
Key findings
- Small 1D function fitting (Benchmark 1)
  - `flash_pz` is ~11.5× faster on forward and ~5.4× faster per training step vs `exact_pz`, with identical test loss (0.0947).
  - The `real` ansatz is not quality-preserving here: test loss is much worse (~0.138 vs 0.0947) despite speedups.
  - `cutn_pz` is only ~1.03× faster per step than `exact_pz` at this tiny scale.
- CIFAR-100 HQKAN-44 (Benchmark 2)
  - `pz` variants maintain accuracy (Top-1 ~34.6–35.1%), while `real` variants collapse (Top-1 ~9.8–10.9%).
  - Among QKAN `pz` solvers, `cutn_pz` is fastest here (104.8 s vs 118.9 s for `exact_pz`), with `flash_pz` close behind (107.6 s).
  - Peak memory improves substantially vs `exact_pz` (862.7 MiB): `flash_pz` (484.9 MiB) and `cutn_pz` (457.9 MiB).
- GPT-2 (batch=1) TinyShakespeare (Benchmark 3)
  - Fastest QKAN: `flash_pz` (14.681 ms/step, 29.4 s total; 1.69× faster than `exact_pz`), with lower peak memory (~697 MiB).
  - Best QKAN loss is achieved by `exact_real` (2.351), but `real` is only recommended for transformers after validation.
  - QKAN variants use far less peak memory than MLP baselines (e.g., `flash_pz` 697 MiB vs `plain_mlp` 1698 MiB).
- GPT-2 (batch=10) WebText (Benchmark 4)
  - Fastest QKAN: `flash_pz`/`flash_real` at 120.0 s total (1.55× faster than `exact_pz`), with the lowest QKAN peak memory (~11,645 MiB).
  - `exact_real`/`cutn_real` are also fast (126–127 s) with slightly higher peak memory (~11,810 MiB).
  - MLP achieves the best loss (6.230) but uses 1.84× more parameters (123.69M vs 67.23M) and much higher peak memory (15,159.6 MiB).
  - `triton_mlp` improves speed (150.6 s) and reduces peak memory vs MLP, but is still slower than QKAN `flash`.
- Extreme synthetic QKAN (Benchmark 5)
  - `flash_pz` is dominant for speed with good convergence: training steps are 20.31× faster than `exact_pz` with identical train loss (0.8585).
  - `cutn_pz` trades speed for memory: only 2.15× faster than `exact_pz`, but cuts peak memory to 1153 MiB (~55% lower than `exact_pz`).
  - `real` variants are extremely fast but do not converge (train loss ~31); not usable in this setting.
Benchmark 1: README Function Fitting
Model: QKAN([1, 1], reps=3) with trainable pre/post activations
Data: 1000 train / 1000 test, 1D, function sin(20x)/(20x)
Training: Adam lr=1e-3, 100 steps
| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 100 Steps | Peak Mem | Avg Mem | Test Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 18 | 9.0 ms | 0.829 ms | 1.00x | 3.328 ms | 1.00x | 332.8 ms | 16.6 MiB | 16.3 MiB | 0.0947 |
| exact_real | real | 13 | 0.5 ms | 0.572 ms | 1.45x | 2.528 ms | 1.32x | 252.8 ms | 16.4 MiB | 16.3 MiB | 0.1381 |
| flash_pz | pz | 18 | 0.4 ms | 0.072 ms | 11.46x | 0.619 ms | 5.38x | 61.9 ms | 16.5 MiB | 16.3 MiB | 0.0947 |
| flash_real | real | 13 | 0.5 ms | 0.080 ms | 10.39x | 0.668 ms | 4.98x | 66.8 ms | 16.4 MiB | 16.3 MiB | 0.1386 |
| cutn_pz | pz | 18 | 0.5 ms | 0.703 ms | 1.18x | 3.245 ms | 1.03x | 324.5 ms | 16.6 MiB | 16.3 MiB | 0.0947 |
| cutn_real | real | 13 | 0.5 ms | 0.484 ms | 1.71x | 2.264 ms | 1.47x | 226.4 ms | 16.4 MiB | 16.3 MiB | 0.1381 |
| kan | bspline | 10 | 1.0 ms | 0.195 ms | 4.25x | 0.595 ms | 5.59x | 59.5 ms | 16.4 MiB | 16.3 MiB | 0.0697 |
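For reference, the Benchmark 1 data can be reproduced with a few lines. This is a stdlib-only sketch; the sampling range [-1, 1] and the convention that sin(20x)/(20x) takes its limit value 1 at x = 0 are assumptions, since the issue does not spell them out.

```python
import math

# Benchmark 1 target: sin(20x)/(20x), with the removable
# singularity at x = 0 filled by its limit, 1.
def target(x: float) -> float:
    return 1.0 if x == 0.0 else math.sin(20.0 * x) / (20.0 * x)

# 1000 evenly spaced 1D samples, mirroring the 1000-train setup.
xs = [-1.0 + 2.0 * i / 999 for i in range(1000)]
ys = [target(x) for x in xs]
```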
Benchmark 2: HQKAN CIFAR-100
HQKAN-44: CNet -> Linear(256, 32) -> QKAN([32, 28]) -> Linear(28, 100)
JHCG: CNet -> Linear(256, 32) -> KAN([32, 28]) -> Linear(28, 100)
Data: CIFAR-100, batch size 1000
Training: Adam lr=1e-3, 50 epochs
| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 50ep Time | Peak Mem | Avg Mem | Test Loss | Top-1 | Top-5 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 83572 | 1.3 ms | 4.134 ms | 1.00x | 47.564 ms | 1.00x | 118.9 s | 862.7 MiB | 42.6 MiB | 2.573 | 35.1% | 66.5% |
| exact_real | real | 79092 | 0.5 ms | 1.437 ms | 2.88x | 44.662 ms | 1.06x | 111.7 s | 373.0 MiB | 42.4 MiB | 3.862 | 10.9% | 33.0% |
| flash_pz | pz | 83572 | 0.5 ms | 0.920 ms | 4.49x | 43.024 ms | 1.11x | 107.6 s | 484.9 MiB | 42.6 MiB | 2.562 | 34.8% | 66.4% |
| flash_real | real | 79092 | 0.4 ms | 0.909 ms | 4.55x | 41.933 ms | 1.13x | 104.8 s | 385.2 MiB | 42.4 MiB | 4.026 | 9.8% | 30.4% |
| cutn_pz | pz | 83572 | 0.4 ms | 2.832 ms | 1.46x | 41.909 ms | 1.13x | 104.8 s | 457.9 MiB | 42.6 MiB | 2.574 | 34.6% | 66.0% |
| cutn_real | real | 79092 | 0.4 ms | 1.403 ms | 2.95x | 48.621 ms | 0.98x | 121.6 s | 373.0 MiB | 42.4 MiB | 3.868 | 10.8% | 32.8% |
| JHCG | bspline | 76404 | 17.4 ms | 0.968 ms | 4.27x | 39.762 ms | 1.20x | 99.4 s | 373.1 MiB | 42.5 MiB | 2.645 | 33.7% | 64.4% |
Benchmark 3: HQKANsformer vs MLP GPT-2
HQKANsformer: GPT-2 (12L, 12H, 768E) with Linear(768,10) -> QKAN([10,10], reps=1) -> Linear(10,768) replacing MLP
MLP GPT-2: Standard GPT-2 with Linear(768,3072) -> GELU -> Linear(3072,768) MLP
Triton MLP GPT-2: Same as MLP but with fused Triton bias+GELU kernel (avoids intermediate materialization)
Data: TinyShakespeare (char-level), batch size 1, block size 1024
Training: AdamW lr=0.0003, betas=(0.9,0.95), weight_decay=0.1, grad_clip=1.0, 2000 iters
| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 2000 Iters | Peak Mem | Avg Mem | Final Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 28.64M | 149.8 ms | 9.101 ms | 1.00x | 24.799 ms | 1.00x | 49.6 s | 810.8 MiB | 492.9 MiB | 2.398 |
| exact_real | real | 28.64M | 215.8 ms | 7.175 ms | 1.27x | 20.721 ms | 1.20x | 41.4 s | 717.1 MiB | 497.0 MiB | 2.351 |
| flash_pz | pz | 28.64M | 216.6 ms | 4.520 ms | 2.01x | 14.681 ms | 1.69x | 29.4 s | 697.2 MiB | 491.2 MiB | 2.399 |
| flash_real | real | 28.64M | 213.4 ms | 4.622 ms | 1.97x | 15.081 ms | 1.64x | 30.2 s | 698.4 MiB | 494.4 MiB | 2.377 |
| cutn_pz | pz | 28.64M | 216.1 ms | 10.049 ms | 0.91x | 27.717 ms | 0.89x | 55.4 s | 784.6 MiB | 493.7 MiB | 2.402 |
| cutn_real | real | 28.64M | 152.1 ms | 7.104 ms | 1.28x | 20.233 ms | 1.23x | 40.5 s | 715.4 MiB | 495.8 MiB | 2.375 |
| JHCG | bspline | 28.64M | 157.9 ms | 5.267 ms | 1.73x | 16.802 ms | 1.48x | 33.6 s | 729.6 MiB | 493.4 MiB | 2.494 |
| triton_mlp | mlp | 85.11M | 411.1 ms | 6.737 ms | 1.35x | 18.086 ms | 1.37x | 36.2 s | 1488.3 MiB | 1040.4 MiB | 2.488 |
| plain_mlp | mlp | 85.11M | 496.5 ms | 7.042 ms | 1.29x | 23.850 ms | 1.04x | 47.7 s | 1697.7 MiB | 1366.3 MiB | 2.528 |
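The `triton_mlp` rows use a fused bias+GELU kernel. Per element, fusion computes gelu(x + bias) in one pass so the intermediate (x + bias) tensor is never written to memory. A scalar sketch of what that kernel computes, using the exact erf-based GELU (PyTorch's `nn.GELU` default):

```python
import math

def gelu(x: float) -> float:
    # Exact (erf-based) GELU.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def fused_bias_gelu(x: float, bias: float) -> float:
    # What the fused kernel computes per element: gelu(x + bias) in one
    # pass, so the intermediate (x + bias) tensor is never materialized.
    return gelu(x + bias)
```

The actual kernel is written in Triton and operates on whole tensors; this only shows the per-element math being fused.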
Benchmark 4: HQKANsformer vs MLP GPT-2 on WebText (batch=10)
HQKANsformer: GPT-2 (12L, 12H, 768E) with Linear(768,10) -> QKAN([10,10], reps=1) -> Linear(10,768) replacing MLP
MLP GPT-2: Standard GPT-2 with Linear(768,3072) -> GELU -> Linear(3072,768) MLP
Triton MLP GPT-2: Same as MLP but with fused Triton bias+GELU kernel (avoids intermediate materialization)
Data: WebText (GPT-2 tokenizer, vocab_size=50304), batch size 10, block size 1024
Training: AdamW lr=0.0003, betas=(0.9,0.95), weight_decay=0.1, grad_clip=1.0, 1000 iters
| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 1000 Iters | Peak Mem | Avg Mem | Final Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 67.23M | 515.6 ms | 66.981 ms | 1.00x | 185.643 ms | 1.00x | 185.6 s | 12830.0 MiB | 3058.1 MiB | 6.491 |
| exact_real | real | 67.22M | 520.8 ms | 42.139 ms | 1.59x | 126.959 ms | 1.46x | 127.0 s | 11811.3 MiB | 3058.9 MiB | 6.407 |
| flash_pz | pz | 67.23M | 529.3 ms | 39.557 ms | 1.69x | 120.033 ms | 1.55x | 120.0 s | 11646.3 MiB | 3047.8 MiB | 6.428 |
| flash_real | real | 67.22M | 567.0 ms | 39.653 ms | 1.69x | 119.964 ms | 1.55x | 120.0 s | 11644.7 MiB | 3046.0 MiB | 6.417 |
| cutn_pz | pz | 67.23M | 536.0 ms | 57.676 ms | 1.16x | 175.005 ms | 1.06x | 175.0 s | 12443.1 MiB | 3045.9 MiB | 6.394 |
| cutn_real | real | 67.22M | 513.9 ms | 41.874 ms | 1.60x | 126.224 ms | 1.47x | 126.2 s | 11807.0 MiB | 3050.1 MiB | 6.409 |
| JHCG | bspline | 67.23M | 511.8 ms | 40.067 ms | 1.67x | 122.693 ms | 1.51x | 122.7 s | 11994.0 MiB | 3050.3 MiB | 6.603 |
| triton_mlp | mlp | 123.69M | 777.7 ms | 64.211 ms | 1.04x | 150.647 ms | 1.23x | 150.6 s | 12781.0 MiB | 3588.1 MiB | 6.332 |
| plain_mlp | mlp | 123.69M | 828.0 ms | 64.256 ms | 1.04x | 183.333 ms | 1.01x | 183.3 s | 15159.6 MiB | 3914.4 MiB | 6.230 |
Benchmark 5: Extreme Synthetic QKAN
Model: QKAN([100, 100], reps=3)
Data: Random 1000x100 input/output
Training: Adam lr=1e-3, 50 steps
| Variant | Ansatz | Params | Init | Forward | Fwd vs exact_pz | Train Step | Step vs exact_pz | 50 Steps | Peak Mem | Avg Mem | Train Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| exact_pz | pz | 180000 | 0.5 ms | 36.709 ms | 1.00x | 78.111 ms | 1.00x | 3905.6 ms | 2569.8 MiB | 41.0 MiB | 0.8585 |
| exact_real | real | 130000 | 0.5 ms | 3.476 ms | 10.56x | 8.568 ms | 9.12x | 428.4 ms | 327.5 MiB | 39.8 MiB | 31.1815 |
| flash_pz | pz | 180000 | 0.5 ms | 0.635 ms | 57.77x | 3.845 ms | 20.31x | 192.3 ms | 1955.5 MiB | 41.0 MiB | 0.8585 |
| flash_real | real | 130000 | 0.5 ms | 0.419 ms | 87.57x | 1.563 ms | 49.97x | 78.2 ms | 842.4 MiB | 39.8 MiB | 31.1592 |
| cutn_pz | pz | 180000 | 0.4 ms | 16.139 ms | 2.27x | 36.334 ms | 2.15x | 1816.7 ms | 1153.0 MiB | 41.0 MiB | 0.8585 |
| cutn_real | real | 130000 | 0.4 ms | 3.785 ms | 9.70x | 8.213 ms | 9.51x | 410.6 ms | 327.5 MiB | 39.8 MiB | 31.1815 |
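As a sanity check, the headline Benchmark 5 claims can be recomputed directly from the table:

```python
# Values copied from the Benchmark 5 table above.
exact_step_ms, flash_step_ms = 78.111, 3.845
exact_peak_mib, cutn_peak_mib = 2569.8, 1153.0

step_speedup = exact_step_ms / flash_step_ms       # ~20.31x (flash_pz vs exact_pz)
mem_saving = 1.0 - cutn_peak_mib / exact_peak_mib  # ~0.55, i.e. ~55% lower peak
```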