Triton error on AMD GPUs #231

Open
eminorhan opened this issue Sep 8, 2024 · 8 comments

eminorhan commented Sep 8, 2024

🐛 Describe the bug

I'm trying to test this library on an HPC cluster with AMD MI250X GPUs, but I'm getting a weird, seemingly Triton-related error, specifically when I call model.train(). The following is a minimal example:

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B", attn_implementation='sdpa', torch_dtype=torch.bfloat16)

x = torch.zeros((4, 128), dtype=int)
x = x.cuda()
model = model.cuda()

model.train()  # runs without an issue when I comment out this line 

y = model(input_ids=x, labels=x)

I get the following error when I run this on an MI250X:

  File "/lustre/orion/stf218/scratch/emin/test_liger/test.py", line 22, in <module>
    y = model(input_ids=x, labels=x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/model/llama.py", line 109, in lce_forward
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/fused_linear_cross_entropy.py", line 13, in forward
    return LigerFusedLinearCrossEntropyFunction.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 193, in forward
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 73, in fused_linear_cross_entropy_forward
    liger_cross_entropy_kernel[(n_rows,)](
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument

The same code snippet runs fine when I leave out model.train(). I also have access to another cluster with NVIDIA GPUs (A100 and H100), and I can confirm that the snippet works there with or without model.train(), so this appears to be an AMD-specific issue. I would appreciate any help debugging it.

Reproduce

No response

Versions

I'm running on PyTorch-nightly + ROCm 6.2 + liger-kernel-nightly:

torch==2.5.0.dev20240906+rocm6.2
triton==3.0.0
liger-kernel-nightly==0.2.1.dev20240908014422
transformers==4.44.2
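
For reference, a quick way to double-check these versions and confirm the ROCm build on the node (a minimal check; torch.version.hip is only set on ROCm builds of PyTorch):

import torch, triton, transformers

print(torch.__version__, torch.version.hip)         # the ROCm build reports the HIP version here
print(triton.__version__, transformers.__version__)
print(torch.cuda.get_device_name(0))                 # should report the MI250X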

ByronHsu (Collaborator) commented Sep 8, 2024

cc @Jokeren do you have any ideas?

ByronHsu (Collaborator) commented Sep 8, 2024

Same error as triton-lang/triton#4128. Also, @helloworld1 has tested on AMD GPUs before. Can you share your experience?

@helloworld1 (Collaborator)

Same error on MI210; I wasn't able to resolve it myself. It looks like a Triton/ROCm compatibility issue.

eminorhan (Author) commented Sep 8, 2024

@ByronHsu Thanks a lot for the pointer. I really appreciate the help. When I manually change the num_warps argument to 64 in fused_linear_cross_entropy.py, it seems to fix this particular issue, but now I get another error in fused_linear_cross_entropy:

[rank16]:   File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 193, in forward
[rank16]:     loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
[rank16]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank16]:   File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 54, in fused_linear_cross_entropy_forward
[rank16]:     logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
[rank16]:                    ~~~~~~~~~~~~~^~~~~~~~~~~~
[rank16]: RuntimeError: size mismatch, got input (16), mat (16x4096), vec (16416768)

I should note that this doesn't come from the minimal code above; it's a slightly more complicated setup with an FSDP wrapper around the model (I couldn't immediately create a minimal example), but I was wondering if you had any idea what might be triggering this size mismatch error.


DocShotgun commented Sep 12, 2024

I got this same error a few weeks ago trying to train on an MI300x with axolotl (on torch 2.4.0+rocm6.1). There was one time I got the training run to start by fiddling with various deps, but unfortunately I could never reproduce it.

EDIT: I was able to get the rope and rms_norm Liger kernels to run without this error for a Llama 3.1 model on the setup mentioned above. The swiglu, cross entropy, and fused linear cross entropy kernels all result in this error, in case that helps narrow things down a little.
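
For anyone who wants a workaround in the meantime, the patcher lets you toggle individual kernels, so you can enable only the ones that currently run cleanly on AMD. A sketch based on my setup (the keyword names may differ slightly across liger-kernel versions, so check the signature of apply_liger_kernel_to_llama in your install):

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Enable only the kernels that ran without the HIP error in my testing;
# leave the failing ones on the stock Hugging Face implementations.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=False,
    cross_entropy=False,
    fused_linear_cross_entropy=False,
)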


Jokeren commented Sep 12, 2024

I don't maintain the AMD backend. Better to try out triton/main or contact the AMD people.

@DocShotgun

Following the logic in the issue linked here (#231 (comment)), and noting that the warp size of AMD Instinct processors is 64 compared to 32 for NVIDIA GPUs, I halved num_warps across the board in my fork of liger kernel (DocShotgun@81db02c). This appears to solve the problem for me, allowing training of a Llama 3.1 8B architecture model on an MI300x.

Training appears to be working fine judging by my logs (slightly faster and significantly less memory usage, with similar loss and grad norms):

No liger:

{'loss': 1.3565, 'grad_norm': 26.5, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.0}                                                                                        
[axolotl.callbacks.on_step_end:128] [PID:8168] [RANK:0] GPU memory usage while training: 29.945GB (+50.608GB cache)
{'loss': 1.3419, 'grad_norm': 29.25, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.01}                                                                                       
{'loss': 1.3337, 'grad_norm': 25.625, 'learning_rate': 7.5e-07, 'epoch': 0.01}                                                                                                    
{'loss': 1.3301, 'grad_norm': 25.875, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}                                                                                     
{'loss': 1.2617, 'grad_norm': 24.375, 'learning_rate': 1.25e-06, 'epoch': 0.02}                                                                                                   
{'loss': 1.2703, 'grad_norm': 22.75, 'learning_rate': 1.5e-06, 'epoch': 0.02}                                                                                                     
{'loss': 1.2243, 'grad_norm': 21.375, 'learning_rate': 1.75e-06, 'epoch': 0.03}                                                                                                   
{'loss': 1.3376, 'grad_norm': 19.0, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.03}                                                                                       
{'loss': 1.287, 'grad_norm': 15.125, 'learning_rate': 2.25e-06, 'epoch': 0.04}                                                                                                    
{'loss': 1.2013, 'grad_norm': 10.0, 'learning_rate': 2.5e-06, 'epoch': 0.04}                                                                                                      
{'loss': 1.1887, 'grad_norm': 8.375, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.04}                                                                                      
  2%|███                                                                                                                                       | 11/506 [02:10<1:36:51, 11.74s/it]

With liger:

{'loss': 1.3556, 'grad_norm': 26.375, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.0}                                                                                      
[axolotl.callbacks.on_step_end:128] [PID:8571] [RANK:0] GPU memory usage while training: 29.948GB (+24.158GB cache)
{'loss': 1.3414, 'grad_norm': 29.25, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.01}                                                                                       
{'loss': 1.334, 'grad_norm': 25.625, 'learning_rate': 7.5e-07, 'epoch': 0.01}                                                                                                     
{'loss': 1.3305, 'grad_norm': 26.375, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}                                                                                     
{'loss': 1.262, 'grad_norm': 24.25, 'learning_rate': 1.25e-06, 'epoch': 0.02}                                                                                                     
{'loss': 1.2709, 'grad_norm': 23.0, 'learning_rate': 1.5e-06, 'epoch': 0.02}                                                                                                      
{'loss': 1.2239, 'grad_norm': 21.25, 'learning_rate': 1.75e-06, 'epoch': 0.03}                                                                                                    
{'loss': 1.3373, 'grad_norm': 18.375, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.03}                                                                                     
{'loss': 1.2871, 'grad_norm': 14.8125, 'learning_rate': 2.25e-06, 'epoch': 0.04}                                                                                                  
{'loss': 1.2012, 'grad_norm': 10.125, 'learning_rate': 2.5e-06, 'epoch': 0.04}                                                                                                    
{'loss': 1.1885, 'grad_norm': 8.25, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.04}                                                                                       
  2%|███                                                                                                                                       | 11/506 [01:54<1:24:52, 10.29s/it]

I know nothing about Triton kernels, so I wanted to ask whether there are any potential adverse consequences to this.
And if not, would it be sufficient to simply halve num_warps when an AMD Instinct processor is detected?
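
Something like the following is what I have in mind, as an untested sketch (amd_aware_num_warps is a hypothetical helper, not part of liger-kernel, and it assumes torch exposes the device's warp size via get_device_properties):

import torch

def amd_aware_num_warps(num_warps: int) -> int:
    # Hypothetical helper: halve num_warps on GPUs whose hardware wavefront size
    # is 64 (AMD Instinct / CDNA); keep the NVIDIA-tuned value (warp size 32) otherwise.
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        if getattr(props, "warp_size", 32) == 64:
            return max(1, num_warps // 2)
    return num_warps

# A launch site that currently hard-codes e.g. num_warps=32 would then become:
# liger_cross_entropy_kernel[(n_rows,)](..., num_warps=amd_aware_num_warps(32))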

@eminorhan (Author)

Ha! I had set num_warps=64, but I think you're right that it should have been 16 instead (I mixed up num_warps with warp_size).
