Torch compiled FLCE is 2x faster than the current FLCE #227
Comments
Actually, if we align the CHUNK_SIZE of the torch-compiled FLCE with the strategy used in Liger's FLCE, the compiled version is only slightly faster than the Liger version, but it does require a bit more memory as well. The advantage of the torch-compiled version is its flexibility: implementing Gemma2's logit softcapping is very straightforward, whereas I struggled for some time to achieve consistent accuracy with it in Liger.
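To make the flexibility point concrete, here is a rough sketch of a chunked, torch.compile-d FLCE with optional Gemma2-style logit softcapping. This is my own illustration, not Liger's kernel or the gist's exact code, and the names (`_chunk_loss`, `chunked_linear_cross_entropy`, `softcap`) are made up for the example; a real implementation also has to avoid keeping every chunk's logits alive for the backward (custom autograd or recomputation), which is omitted here.

```python
import torch
import torch.nn.functional as F

@torch.compile
def _chunk_loss(x_chunk, weight, target_chunk, softcap=None):
    # Per-chunk logits; the matmul, optional softcap, upcast and loss are fused by the compiler.
    logits = x_chunk @ weight.t()                        # (chunk, V)
    if softcap is not None:
        logits = softcap * torch.tanh(logits / softcap)  # Gemma2-style logit softcapping
    return F.cross_entropy(logits.float(), target_chunk, reduction="sum")

def chunked_linear_cross_entropy(x, weight, target, chunk_size=256, softcap=None):
    # x: (B*T, H), weight: (V, H), target: (B*T,)  -- shapes/names are illustrative
    n = x.shape[0]
    total = x.new_zeros((), dtype=torch.float32)
    for start in range(0, n, chunk_size):
        total = total + _chunk_loss(
            x[start:start + chunk_size], weight, target[start:start + chunk_size], softcap
        )
    return total / n
```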
@wizyoung how are you setting the chunk size? I wasn't able to get the Liger kernel to perform much better even when changing the chunk size.
@Chillee By referencing https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py#L23. I mean changing the chunk size in the torch-compiled FLCE: your default chunk size is 1024, and I changed it to 256. Env: torch 2.3.1, triton 2.3.1, A100 80G, CUDA 12.3.
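(For anyone following along, the chunk-size heuristic at the linked line roughly looks like the sketch below. This is paraphrased from my reading of the Liger source and may not match the exact formula, so treat it as an approximation and check the file itself.)

```python
import math

def liger_style_chunk_size(BT: int, H: int, V: int) -> int:
    # Paraphrase of the heuristic at the linked line (check the source for the exact formula):
    # split B*T into more chunks when the vocab V is much larger than the hidden size H,
    # so the materialized (chunk_size, V) logits stay roughly comparable to the (B*T, H) input.
    inc_factor = math.ceil(V / H)                                        # ~ triton.cdiv(V, H)
    chunk_size = 2 ** math.ceil(math.log2(math.ceil(BT / inc_factor)))   # ~ next_power_of_2
    return chunk_size

# llama3-like shapes: B*T = 4 * 2048, H = 4096, V = 128256  ->  256
print(liger_style_chunk_size(4 * 2048, 4096, 128256))
```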
I have done some quick tests with different B, T, D, and V to mimic my training conditions (llama3 and gemma2) in my environment. My conclusion is that the torch-compiled FLCE is indeed faster, but has worse memory management.
https://gist.github.com/wizyoung/5330ad501e73a97dfe2f0088decdb1ca |
@wizyoung I agree there's some additional memory overhead (in particular, I think we don't inplace the addmm), but the additional memory is generally pretty negligible here, no? For example, if I change the chunk size from 256 to 512, torch.compile performance improves from 186 ms down to 153 ms, while memory only increases from 1.48 GB to 1.54 GB. If I try increasing the chunk size of Liger, it doesn't seem to improve performance as much as it does for the torch.compile version.
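(For anyone reproducing these numbers, a generic way to measure both wall-clock time and peak allocator memory per chunk size is sketched below; this is not the script from the gist, just a minimal CUDA-events helper.)

```python
import torch

def measure(fn, warmup=3, iters=10):
    # Time one configuration with CUDA events and report peak allocator memory.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters, torch.cuda.max_memory_allocated() / 1e9

# Usage sketch (x, w, tgt and chunked_linear_cross_entropy are from the earlier sketch):
# for chunk_size in (256, 512, 1024):
#     ms, gb = measure(lambda: chunked_linear_cross_entropy(x, w, tgt, chunk_size=chunk_size))
#     print(f"chunk_size={chunk_size}: {ms:.1f} ms, peak {gb:.2f} GB")
```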
Curious how this compares with JonasGeiping/linear_cross_entropy_loss, but
@ekojsalim In my brief testing, it seems like it's both faster and uses less memory.
Yes, the increase in memory usage is generally negligible. My primary concern is the runtime overhead: specifically, B*T varies significantly and is not a multiple of the chunk size, leading to frequent recompiles (I add
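(Two general torch.compile workarounds for that recompile churn are sketched below. These are assumptions on my part, not something taken from the gist: mark the varying token dimension as dynamic, or pad B*T up to a multiple of the chunk size with ignored targets.)

```python
import torch

# Option 1: let the compiler treat the token dimension as dynamic instead of
# specializing on every new B*T (reusing the chunked_linear_cross_entropy sketch above).
compiled_loss = torch.compile(chunked_linear_cross_entropy, dynamic=True)
# Or mark only the varying dimension on the inputs:
# torch._dynamo.mark_dynamic(x, 0)
# torch._dynamo.mark_dynamic(target, 0)

# Option 2: pad B*T up to a multiple of the chunk size with ignore_index targets,
# so every chunk (including the last) has the same static shape. The loss must then
# skip the padded positions and normalize by the real token count.
def pad_to_multiple(x, target, chunk_size, ignore_index=-100):
    pad = (-x.shape[0]) % chunk_size
    if pad:
        x = torch.cat([x, x.new_zeros(pad, x.shape[1])])
        target = torch.cat([target, target.new_full((pad,), ignore_index)])
    return x, target
```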
@wizyoung Can you post your benchmark script?
@Chillee I have updated my scripts here: https://gist.github.com/wizyoung/5330ad501e73a97dfe2f0088decdb1ca
🚀 The feature, motivation and pitch
We can leverage torch.compile to fuse the things we cannot currently fuse, such as upcasting, contiguous calls, etc.
Sample code: https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
Provided by the brilliant @Chillee
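As a rough illustration of the idea (the actual sample code is in the gist above; this sketch only shows the unchunked core and where the compiler helps):

```python
import torch
import torch.nn.functional as F

# Naive, unchunked core written in plain PyTorch; torch.compile fuses the fp32
# upcast and surrounding elementwise ops instead of materializing extra copies.
# The gist builds on this idea with chunking so the full (B*T, V) logits are
# never held at once.
@torch.compile
def fused_linear_cross_entropy(hidden, lm_head_weight, labels):
    # hidden: (B*T, H), lm_head_weight: (V, H), labels: (B*T,) -- illustrative names
    logits = hidden @ lm_head_weight.t()
    return F.cross_entropy(logits.float(), labels)
```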
Alternatives
No response
Additional context
No response