Performance comparison NF4Tensor vs. BNB Params4bit #1686
cc @drisspg to comment on NF4 performance.
This subclass was specifically designed to be torch.compile friendly and to recover or exceed performance. Can you try the same script but w/ torch.compile enabled?
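For reference, a minimal sketch of what running such a forward under torch.compile could look like, assuming the torchao NF4 path via `to_nf4` / `linear_nf4` (the actual benchmark script is not shown at this point in the thread; names and sizes are illustrative):

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

device = "cuda"

# Quantize a bf16 weight to NF4 and define the linear forward around it.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
nf4_weight = to_nf4(weight)

def nf4_forward(x):
    # linear_nf4 dequantizes the NF4 weight and applies a standard linear.
    return linear_nf4(x, nf4_weight)

# Compile with static shapes so later timings measure the specialized kernels.
compiled_forward = torch.compile(nf4_forward, dynamic=False)

x = torch.randn(32, 4096, dtype=torch.bfloat16, device=device)
out = compiled_forward(x)  # first call triggers compilation; time subsequent calls
```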
@drisspg Can you share the script?
Pimped up my script a bit, hope there are no bugs. Likely it is just something wrong on my side, as the results roughly come close to what I see in QLoRA training, where there is around a 50% slowdown. I am using a single L40S for the test; my results:
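For context, a minimal sketch (not the results above) of how the memory numbers in such a comparison are often collected: reset the CUDA allocator's peak counter, run the forward, and read the peak back. The helper name is illustrative; the actual script is not reproduced here.

```python
import torch

def peak_memory_mb(fn, *args):
    """Run fn(*args) once and report peak CUDA memory allocated during the call, in MB."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2
```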
Hey, I made some tweaks to only measure GPU (kernel) time. I also set max-autotune and bumped the compile cache limit, since that's a common gotcha (a rough sketch of this setup follows the tables below). Summary statistics:
| batch_size | hidden_dim | memory_mb (NF4Tensor) | memory_mb (Params4bit) | time_ms (NF4Tensor) | time_ms (Params4bit) | time_std_ms (NF4Tensor) | time_std_ms (Params4bit) |
|---:|---:|---:|---:|---:|---:|---:|---:|
| 1 | 1024 | 250.691797 | 246.725098 | 13.039025 | 9.413793 | 0.008531 | 0.050377 |
| 1 | 2048 | 270.334473 | 254.463867 | 37.675816 | 12.320583 | 0.045331 | 0.042032 |
| 1 | 4096 | 348.902344 | 285.412109 | 171.810753 | 20.014502 | 0.248672 | 0.035495 |
| 1 | 8192 | 661.666016 | 409.189453 | 835.910086 | 51.097615 | 2.304483 | 0.387800 |
| 8 | 1024 | 280.689941 | 280.736816 | 8.791172 | 15.231156 | 0.003363 | 0.007602 |
| 8 | 2048 | 294.268066 | 294.487305 | 17.917699 | 24.123021 | 0.020156 | 0.011100 |
| 8 | 4096 | 348.519531 | 349.458984 | 52.671854 | 58.791292 | 0.014061 | 0.066261 |
| 8 | 8192 | 565.400391 | 569.283203 | 229.993638 | 241.503009 | 0.033971 | 0.057382 |
| 32 | 1024 | 280.783691 | 280.783691 | 8.720310 | 15.160562 | 0.005351 | 0.016203 |
| 32 | 2048 | 294.455566 | 294.581055 | 18.045767 | 24.283645 | 0.010421 | 0.010635 |
| 32 | 4096 | 348.894531 | 349.646484 | 53.511519 | 59.496724 | 0.044118 | 0.051886 |
| 32 | 8192 | 566.150391 | 569.658203 | 230.906830 | 242.464692 | 0.050207 | 0.105864 |
| 128 | 1024 | 281.158691 | 281.158691 | 9.032877 | 15.513316 | 0.014203 | 0.036833 |
| 128 | 2048 | 295.205566 | 295.206055 | 18.612753 | 24.917836 | 0.027866 | 0.007975 |
| 128 | 4096 | 350.394531 | 350.396484 | 80.006724 | 59.379454 | 0.041938 | 0.025266 |
| 128 | 8192 | 569.150391 | 571.158203 | 298.323062 | 241.903570 | 3.526173 | 0.472699 |
Speedup ratios (Params4bit time / NF4Tensor time):

| batch_size | hidden_dim | ratio |
|---:|---:|---:|
| 1 | 1024 | 0.721971 |
| 1 | 2048 | 0.327016 |
| 1 | 4096 | 0.116492 |
| 1 | 8192 | 0.061128 |
| 8 | 1024 | 1.732551 |
| 8 | 2048 | 1.346324 |
| 8 | 4096 | 1.116180 |
| 8 | 8192 | 1.050042 |
| 32 | 1024 | 1.738535 |
| 32 | 2048 | 1.345670 |
| 32 | 4096 | 1.111849 |
| 32 | 8192 | 1.050054 |
| 128 | 1024 | 1.717428 |
| 128 | 2048 | 1.338751 |
| 128 | 4096 | 0.742181 |
| 128 | 8192 | 0.810878 |
There is definitely a gap at the small sizes; I imagine that is because we aren't doing any prologue fusion for NF4. cc @eellison, would be fun to try w/ the updates.
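A rough sketch of the tweaks described above (GPU time via CUDA events, `mode="max-autotune"`, and a raised dynamo cache limit); the exact benchmark script is not reproduced in the thread, so function and variable names here are illustrative:

```python
import torch
import torch._dynamo

# The recompile cache defaults to 8 entries per compiled function; sweeping many
# (batch_size, hidden_dim) combinations in one process can exhaust it quickly.
torch._dynamo.config.cache_size_limit = 64

def gpu_time_ms(fn, *args, warmup=5, iters=50):
    """Measure kernel time with CUDA events instead of host-side wall clock."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# e.g. compile the NF4 forward from the earlier sketch with autotuning and static shapes:
# compiled = torch.compile(nf4_forward, mode="max-autotune", dynamic=False)
# print(gpu_time_ms(compiled, x))
```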
Thanks @drisspg - some interesting results. A few observations:
- On your benchmark script, changing the compile limit does not change my results by much.
- Also, it seems just compiling
- The memory usage also seems to go up a bit after compiling.
So the compile limit is really a footgun when benchmarking compile in microbenchmarks. I actually realized I didn't set dynamic=False. So in this case it isn't doing anything but re-running; w/ the limit bumped and dynamic=False it should increase perf by some amount.
That is surprising to me.
It is also very surprising to me. Here are three of my versions, all doing QLoRA on

First version:

Runtime: 04:09

Second version:

^^ This already doubles the runtime.

Third version:

^^ So I was (thankfully) wrong about this in the previous post.

Now I tested a few cache limits and it seems it needs at least:

Apparently, the default is only 8; I thought it would be larger. Now I'm wondering why the default is so small and what a good value is to set it to. I have dynamic input shapes, so this might play a role here.
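For reference, a minimal sketch of raising the cache limit (the value chosen here is illustrative; the thread does not settle on a recommended one):

```python
import torch
import torch._dynamo

# torch._dynamo.config.cache_size_limit defaults to 8 recompiles per code object.
# With dynamic input shapes, each new shape can trigger a recompile; once the
# limit is exceeded, dynamo stops compiling that frame and falls back to eager.
torch._dynamo.config.cache_size_limit = 64  # illustrative, tune for your workload

# Alternatively, compile with dynamic shapes so fewer specializations are generated:
# model = torch.compile(model, dynamic=True)
```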
For testing #1684, related to #1665, I have been comparing side by side my QLoRA implementation with bitsandbytes `Params4bit` vs. torchao `NF4Tensor` on single-GPU or DDP setups. I noticed that the torchao script has nearly twice the runtime vs. the bitsandbytes implementation. Here is the main difference:

vs.

I ran a few simple timings in a quick script, and indeed the forward appears to be nearly twice as slow on simple input and weight tensors in `bfloat16`. Digging a bit deeper, it seems to be mainly due to the dequantization step. In torchao, `weight.to(torch.bfloat16)` seems to be quite a bit slower vs. `bitsandbytes.functional.dequantize_4bit(weight, weight.quant_state, quant_type="nf4")`.

I'm wondering if I'm either missing some difference for a fair comparison, or some functionality, or whether this is expected. Happy for any guidance. Thanks!
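A hedged sketch of the kind of quick dequantization timing described above, assuming `to_nf4` from torchao and `Params4bit` / `dequantize_4bit` from bitsandbytes (sizes and iteration counts are illustrative):

```python
import torch
import bitsandbytes.functional as bnbF
from bitsandbytes.nn import Params4bit
from torchao.dtypes.nf4tensor import to_nf4

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# torchao: quantize to NF4, dequantize via .to(torch.bfloat16)
nf4_weight = to_nf4(weight)

# bitsandbytes: Params4bit quantizes when moved to the GPU
bnb_weight = Params4bit(weight.cpu(), requires_grad=False, quant_type="nf4").cuda()

def time_ms(fn, iters=50):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("torchao NF4 dequant (ms):", time_ms(lambda: nf4_weight.to(torch.bfloat16)))
print("bnb dequantize_4bit (ms):", time_ms(
    lambda: bnbF.dequantize_4bit(bnb_weight, bnb_weight.quant_state, quant_type="nf4")))
```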