Performance comparison NF4Tensor vs. BNB Params4bit #1686

Open
psinger opened this issue Feb 10, 2025 · 9 comments

@psinger

psinger commented Feb 10, 2025

For testing #1684 (related to #1665), I have been comparing my QLoRA implementation side by side using bitsandbytes Params4bit vs. torchao NF4Tensor, on single-GPU and DDP setups.

I noticed that the torchao version has nearly twice the runtime of the bitsandbytes implementation. Here is the main difference:

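import torch.nn as nn
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

# torchao: quantize the bf16 weight to NF4 (blocks of 64, scalers in blocks of 256)
weight = nn.Parameter(
    to_nf4(weight, block_size=64, scaler_block_size=256)
)

linear_nf4(input=x, weight=weight)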

vs.

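import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit

# bitsandbytes: quantize the same weight to NF4 (blocks of 64) on the GPU
weight = Params4bit(
    data=weight,
    quant_type="nf4",
    blocksize=64
)._quantize(device="cuda:0")

bnb.matmul_4bit(
    x,
    weight.t(),
    quant_state=weight.quant_state,
)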

I ran a few simple timings in a quick script, and the forward does appear to be nearly twice as slow on simple bfloat16 input and weight tensors. Digging a bit deeper, the difference seems to come mainly from the dequantization step.

In torchao, weight.to(torch.bfloat16) seems to be considerably slower than bitsandbytes.functional.dequantize_4bit(weight, weight.quant_state, quant_type="nf4").
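For intuition about what the dequantization step computes, here is a minimal pure-Python sketch of blockwise NF4 quantization and dequantization. The 16 code values are the NF4 levels from the QLoRA paper (the same table bitsandbytes and torchao use); the function names are hypothetical, and real kernels pack two 4-bit indices per byte, quantize the scalers too, and run on the GPU. Dequantization is just a table lookup plus one multiply per element, so its cost is likely dominated by memory traffic, which is why whether it gets fused into the matmul matters so much.

```python
# Minimal sketch of blockwise NF4 (absmax scaling + nearest-code rounding).
# Assumption: scalers are kept as plain floats (no double quantization).

NF4_CODE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.2461123913526535,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def quantize_nf4(values, block_size=64):
    """Return (4-bit indices, absmax per block) for a flat list of floats."""
    indices, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        scales.append(absmax)
        for v in block:
            normalized = v / absmax  # now in [-1, 1]
            # index of the nearest NF4 code value
            idx = min(range(16), key=lambda i: abs(NF4_CODE[i] - normalized))
            indices.append(idx)
    return indices, scales


def dequantize_nf4(indices, scales, block_size=64):
    """Look up each index and rescale by its block's absmax."""
    return [NF4_CODE[idx] * scales[i // block_size] for i, idx in enumerate(indices)]
```

For example, quantize_nf4([0.5, -1.0, 0.25, 0.0], block_size=4) maps the values to indices [12, 0, 10, 7] with a single scaler of 1.0.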

I'm wondering whether I'm missing some difference needed for a fair comparison, or some functionality, or whether this is expected. Happy for any guidance. Thanks!

@supriyar
Contributor

cc @drisspg to comment on NF4 performance

@drisspg
Contributor

drisspg commented Feb 10, 2025

This subclass was specifically designed to be torch.compile friendly and to recover or exceed performance when compiled. Can you try the same script but with torch.compile enabled?

@psinger
Author

psinger commented Feb 11, 2025

@drisspg
Unfortunately, I'm not seeing any speed improvement with torch.compile.

@drisspg
Contributor

drisspg commented Feb 11, 2025

Can you share the script you are using to compare performance?

@psinger
Author

psinger commented Feb 11, 2025

I polished my script a bit; hopefully there are no bugs:
https://gist.github.com/psinger/4c0be78770d1b84d641e9dab2208c9b0

It's likely just something wrong on my side, as the results roughly come close to what I see in QLoRA training, where there is around a 50% slowdown.

I am using a single L40S for the test. My results:

Summary Statistics:
================================================================================
                         memory_mb               time_ms            time_std_ms
method                   NF4Tensor  Params4bit NF4Tensor Params4bit   NF4Tensor Params4bit
batch_size hidden_dim
1          1024          76.363672   14.127832  0.182146   0.047314    0.015026   0.003152
           2048          92.318848   32.129883  0.174555   0.048649    0.004904   0.004196
           4096         272.894531  104.133789  0.667517   0.052217    0.009201   0.003340
           8192        1067.166016  392.141602  3.729443   0.076748    0.007220   0.018644
8          1024          77.189941   14.141602  0.174030   0.070059    0.008859   0.003127
           2048          92.346191   32.157227  0.179347   0.070407    0.017062   0.003802
           4096         273.003906  104.188477  0.663882   0.110282    0.002846   0.011656
           8192        1067.384766  392.250977  3.768017   0.435565    0.004565   0.020859
32         1024          77.236816   14.188477  0.167706   0.067893    0.002750   0.001696
           2048          92.439941   32.250977  0.179633   0.073396    0.021870   0.011210
           4096         273.378906  104.375977  0.668995   0.111383    0.005665   0.003276
           8192        1068.134766  392.625977  3.787099   0.453304    0.001918   0.007065
128        1024          77.424316   14.375977  0.173956   0.070141    0.012143   0.010937
           2048          92.814941   32.625977  0.178456   0.076830    0.008897   0.001315
           4096         274.878906  105.125977  0.671623   0.115683    0.008129   0.009106
           8192        1071.134766  394.125977  3.799423   0.495808    0.001588   0.007795

Speedup Ratios (Params4bit time / NF4Tensor time):
================================================================================
batch_size  hidden_dim
1           1024          0.259758
            2048          0.278701
            4096          0.078225
            8192          0.020579
8           1024          0.402568
            2048          0.392573
            4096          0.166117
            8192          0.115595
32          1024          0.404836
            2048          0.408588
            4096          0.166493
            8192          0.119697
128         1024          0.403211
            2048          0.430526
            4096          0.172244
            8192          0.130496
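To read the ratio tables: each value is Params4bit time divided by NF4Tensor time, so values below 1.0 mean Params4bit is faster and the NF4Tensor slowdown factor is the reciprocal. A small helper makes that explicit; the function names are hypothetical and the inputs are taken from the table rows above.

```python
def speedup_ratio(params4bit_ms, nf4tensor_ms):
    """Ratio as defined in the tables: Params4bit time / NF4Tensor time."""
    return params4bit_ms / nf4tensor_ms


def describe(ratio):
    """Turn a ratio into a human-readable comparison for NF4Tensor."""
    if ratio < 1.0:
        return f"NF4Tensor is {1.0 / ratio:.1f}x slower"
    return f"NF4Tensor is {ratio:.1f}x faster"


# batch_size=1, hidden_dim=8192 from the first summary table
print(describe(speedup_ratio(0.076748, 3.729443)))
```

So the 0.020579 entry for batch_size=1, hidden_dim=8192 corresponds to NF4Tensor being roughly 48.6x slower in that configuration.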

@drisspg
Contributor

drisspg commented Feb 11, 2025

Hey, I made some tweaks to measure only GPU (kernel) time:
https://gist.github.com/drisspg/89255396e2e682798c75b05f63ab0431

I also set max-autotune and raised the compile cache limit, since that's a common gotcha.
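The general shape of a fair microbenchmark is: run untimed warmup calls first (so compilation and autotuning are excluded), then time many iterations and report a robust statistic. Below is a simple wall-clock sketch of that pattern; it is not the method in the gist above, which measures kernel time only. The `benchmark` helper and its `sync` hook are hypothetical names: on GPU you would pass something like torch.cuda.synchronize, because CUDA kernels launch asynchronously, and for kernel-only numbers you would reach for torch.profiler or CUDA events (torch.cuda.Event(enable_timing=True)) instead.

```python
import statistics
import time


def benchmark(fn, *args, warmup=10, iters=100, sync=None):
    """Time fn(*args); return (median_ms, stdev_ms). iters must be >= 2.

    warmup: untimed calls so one-off costs (compile, autotune) are excluded.
    sync: optional callable run before each clock read, e.g. a device sync.
    """
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()  # make sure warmup work has finished before timing starts
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # wait for asynchronous work launched by fn
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times), statistics.stdev(times)
```

Reporting the median rather than the mean keeps occasional outliers (for example, a late recompile) from dominating the result.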

Summary Statistics:
================================================================================
                        memory_mb                 time_ms             time_std_ms           
method                  NF4Tensor  Params4bit   NF4Tensor  Params4bit   NF4Tensor Params4bit
batch_size hidden_dim                                                                       
1          1024        250.691797  246.725098   13.039025    9.413793    0.008531   0.050377
           2048        270.334473  254.463867   37.675816   12.320583    0.045331   0.042032
           4096        348.902344  285.412109  171.810753   20.014502    0.248672   0.035495
           8192        661.666016  409.189453  835.910086   51.097615    2.304483   0.387800
8          1024        280.689941  280.736816    8.791172   15.231156    0.003363   0.007602
           2048        294.268066  294.487305   17.917699   24.123021    0.020156   0.011100
           4096        348.519531  349.458984   52.671854   58.791292    0.014061   0.066261
           8192        565.400391  569.283203  229.993638  241.503009    0.033971   0.057382
32         1024        280.783691  280.783691    8.720310   15.160562    0.005351   0.016203
           2048        294.455566  294.581055   18.045767   24.283645    0.010421   0.010635
           4096        348.894531  349.646484   53.511519   59.496724    0.044118   0.051886
           8192        566.150391  569.658203  230.906830  242.464692    0.050207   0.105864
128        1024        281.158691  281.158691    9.032877   15.513316    0.014203   0.036833
           2048        295.205566  295.206055   18.612753   24.917836    0.027866   0.007975
           4096        350.394531  350.396484   80.006724   59.379454    0.041938   0.025266
           8192        569.150391  571.158203  298.323062  241.903570    3.526173   0.472699

Speedup Ratios (Params4bit time / NF4Tensor time):
================================================================================
batch_size  hidden_dim
1           1024          0.721971
            2048          0.327016
            4096          0.116492
            8192          0.061128
8           1024          1.732551
            2048          1.346324
            4096          1.116180
            8192          1.050042
32          1024          1.738535
            2048          1.345670
            4096          1.111849
            8192          1.050054
128         1024          1.717428
            2048          1.338751
            4096          0.742181
            8192          0.810878
Name: ratio, dtype: float64

There is definitely a gap at the small sizes; I imagine that is because we aren't doing any prologue fusion for NF4.
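Roughly, "prologue fusion" means folding the op that produces a matmul input (here, NF4 dequantization) into the matmul kernel itself, so the full bf16 weight never has to be written to memory and read back. The conceptual pure-Python sketch below is not Inductor code: it compares a matvec that first materializes the dequantized weight with one that dequantizes each element as it is loaded. It assumes a hypothetical 2-bit lookup table to stay short; both versions give identical results, but the fused one never builds the temporary buffer `w`.

```python
# Conceptual sketch of fusing dequantization into a matvec.
TOY_CODE = [-1.0, -0.25, 0.25, 1.0]  # hypothetical 2-bit code table


def dequant(codes, scales, block_size):
    return [TOY_CODE[c] * scales[i // block_size] for i, c in enumerate(codes)]


def unfused_matvec(x, codes, scales, block_size):
    n = len(x)
    w = dequant(codes, scales, block_size)  # full temporary weight buffer
    rows = len(w) // n
    return [sum(x[k] * w[r * n + k] for k in range(n)) for r in range(rows)]


def fused_matvec(x, codes, scales, block_size):
    n = len(x)
    rows = len(codes) // n
    out = []
    for r in range(rows):
        acc = 0.0
        for k in range(n):
            i = r * n + k
            # dequantize this one element only when the dot product needs it
            acc += x[k] * (TOY_CODE[codes[i]] * scales[i // block_size])
        out.append(acc)
    return out
```

In a real GPU kernel the fused form saves a full write and read of the bf16 weight, which plausibly matters most when the matmul itself is cheap.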

cc @eellison, it would be fun to try this with the updates.

@psinger
Author

psinger commented Feb 12, 2025

Thanks @drisspg - some interesting results.

A few observations:

With your benchmark script, changing the compile limit does not change my results by much.
However, the compile limit seems quite important in the full QLoRA script: without raising it, I see no speedup. Is this expected?
How should I choose the proper limit here?

Also, it seems that compiling only linear_nf4 brings a speedup, but compiling more of the model does not, even though linear_nf4 is part of it. For example, this forward:
https://github.com/pytorch/torchtune/blob/main/torchtune/modules/peft/lora.py#L120

The memory usage also seems to go up a bit after compiling.
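For reference, the linked LoRA forward combines the frozen quantized projection with a trainable low-rank update. A minimal sketch in plain Python lists, assuming the standard LoRA formulation out = W x + (alpha / rank) * B(A x) and omitting dropout; the function names are hypothetical, and the linked torchtune code is the authoritative version. Compiling the whole forward lets the compiler see both paths in one graph, whereas compiling linear_nf4 alone only covers the base matmul.

```python
def matvec(w, x):
    """w is a list of rows; returns w @ x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def lora_forward(x, base_w, lora_a, lora_b, alpha, rank):
    # Frozen base projection; in QLoRA this is the NF4 matmul (linear_nf4)
    out = matvec(base_w, x)
    # Trainable low-rank path B @ (A @ x); dropout on x omitted (eval mode)
    lora_out = matvec(lora_b, matvec(lora_a, x))
    scale = alpha / rank
    return [o + scale * l for o, l in zip(out, lora_out)]


# 2x2 identity base weight, rank-1 adapter (A is 1x2, B is 2x1)
print(lora_forward([1.0, 2.0], [[1, 0], [0, 1]], [[1, 1]], [[1], [2]], alpha=2, rank=1))
```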

@drisspg
Contributor

drisspg commented Feb 12, 2025

So the compile limit is really a footgun when benchmarking compile in microbenchmarks. I actually realized I didn't set dynamic=False, so in this case it isn't doing anything; re-running with the limit bumped and compile=False should increase perf by some amount.

> Also, it seems just compiling linear_nf4 brings a speedup, but compiling more of the model, does not bring any speedup, although linear_nf4 is part of it. For example this forward:

That is surprising to me.

@psinger
Author

psinger commented Feb 12, 2025

> So the compile limit is really a foot gun of benchmarking compile in microbenchmarks. I actually realized I didn't set dynamic=False. So in this case it doing anything but re-running w/ the limit bumped and compile=False shoudl increase perf by some amoutn.
>
> > Also, it seems just compiling linear_nf4 brings a speedup, but compiling more of the model, does not bring any speedup, although linear_nf4 is part of it. For example this forward:
>
> that is surprising to me

It is also very surprising to me. Here are three versions of my script, all doing QLoRA on Qwen2.5-7B-Instruct with NF4Tensor, similar to how it is done in torchtune:

First version:

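import torch
from torchao.dtypes.nf4tensor import linear_nf4
# Compile only the NF4 matmul; self.weight is the NF4Tensor parameter
torch._dynamo.config.cache_size_limit = 10000
linear_nf4_compiled = torch.compile(linear_nf4, mode="max-autotune-no-cudagraphs")
out = linear_nf4_compiled(input=x, weight=self.weight)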

Runtime: 04:09
This matches the bitsandbytes runtime without compile.

Second version:
Same as the first, but with the single cache-size-limit line removed.
Runtime: 08:00

^^ This change alone doubles the runtime.
It is similar to what I get without compiling at all.

Third version:
Cache limit restored, but instead of compiling the linear_nf4 function, I put this decorator on top of the LoraLinear forward:
@torch.compile(mode="max-autotune-no-cudagraphs")
Runtime: 04:04

^^ So I was (thankfully) wrong about this in the previous post

I have now tested a few cache limits, and it seems to need at least:
torch._dynamo.config.cache_size_limit = 32

Apparently the default is only 8; I thought it would be larger.

Now I'm wondering why the default is so small and what a good value to set it to would be.

I have dynamic input shapes, so this might play a role here.
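A toy model of why the limit matters: Dynamo keeps a bounded number of compiled variants per function (torch._dynamo.config.cache_size_limit), and once that bound is exceeded it stops recompiling that function and runs it eagerly, which would be consistent with the second version's runtime matching the uncompiled one. If one compiled function is called across many layers and input shapes, the variant count can plausibly exceed 8, though that is a guess here. The class below is a deliberately simplified, hypothetical sketch that specializes only on input length; real Dynamo guards on much more, and automatic dynamic shapes or dynamic=False change how variants accumulate. Rather than guessing a value, running with TORCH_LOGS="recompiles" may show how many variants actually get compiled and why.

```python
class ToyCompileCache:
    """Toy stand-in (not Dynamo) for a per-function compile cache that
    specializes on input length and falls back to eager past a limit."""

    def __init__(self, fn, cache_size_limit=8):
        self.fn = fn
        self.limit = cache_size_limit
        self.variants = {}  # input length -> "compiled" variant
        self.eager_calls = 0

    def __call__(self, x):
        key = len(x)  # specialize on shape, like a static-shape compile
        if key in self.variants:
            return self.variants[key](x)  # cache hit: reuse the variant
        if len(self.variants) < self.limit:
            self.variants[key] = self.fn  # stand-in for a real recompile
            return self.fn(x)
        self.eager_calls += 1  # over the limit: run uncompiled
        return self.fn(x)
```

With a limit of 8 and 20 distinct input lengths, 12 of the first 20 calls run eagerly; with a limit of 32, none do.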
