sgmv_cutlass calculates wrong output #11

Open
harryhan618 opened this issue Nov 17, 2023 · 8 comments
@harryhan618

I'm running the following code and the output comes out wrong. I initialize x and w to all ones, so every element of the output y should be h1 = 4096.

But my output is not. Half of the output is 4096 and the other half is 2528. Weird!
My observation is that the wrong answer happens when h2 >= 32 for shrink.

The following code is adapted from benchmarks/bench_sgmv_cutlass.py

import torch
import punica.ops

bs = 4
h1 = 4096
h2 = 32
num_layers = 1
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [2, 2]

# One all-ones weight matrix per problem; sgmv_cutlass expects row-major (h1, h2) weights.
w = [
    torch.ones((num_layers, h1, h2), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
# Segment offsets: cumulative sum of the problem sizes, starting from 0.
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0)

print(y)
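
For reference, here is a small sanity check (not part of the original benchmark script) that compares each segment of y against a plain matmul; with all-ones inputs every entry should be h1 = 4096:

# Sanity check: each segment of y should equal a plain matmul of the
# corresponding segment of x with the row-major weight of that problem.
for i in range(len(problem_sizes)):
    lo, hi = int(s[i]), int(s[i + 1])
    expected = x[lo:hi] @ w[i][0]          # (problem_sizes[i], h2), all 4096
    print(torch.allclose(y[lo:hi], expected))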
@abcdabcd987
Contributor

Hmm... That's interesting... BTW, thanks for providing this script. Super helpful for reproducing the bug!

We'll take a look at this. In the meantime, you can use punica.ops.sgmv() for SGMV-shrink and punica.ops.add_lora_sgmv_custom_cutlass() for LoRA. Note that our custom kernel assumes column-major weights whereas our cutlass kernel assumes row-major weights.

The following works:

import torch
import punica.ops

bs = 4
h1 = 4096
h2 = 32
num_layers = 1
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [2, 2]

# Note the transposed layout: the custom sgmv kernel expects column-major
# (h2, h1) weights, unlike sgmv_cutlass above.
w = [
    torch.ones((num_layers, h2, h1), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0)

print(y)
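
The same sanity check as above (again a sketch, not part of the script) also passes here; note the extra transpose because the weight is stored as (h2, h1):

# Sanity check for the column-major layout: transpose the weight before the matmul.
for i in range(len(problem_sizes)):
    lo, hi = int(s[i]), int(s[i + 1])
    expected = x[lo:hi] @ w[i][0].T        # (problem_sizes[i], h2), all 4096
    print(torch.allclose(y[lo:hi], expected))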

@harryhan618
Author

Thanks for your reply!
I'm curious why you chose column-major weights. My basic understanding is that row-major is friendlier for data loading. Sorry, I haven't read the kernel code yet.

@yzh119
Contributor

yzh119 commented Nov 20, 2023

@harryhan618 Modern GPUs support transpose at the fragment level (with ldmatrix.***.trans/movmatrix instructions) at very low cost, so there should not be a significant performance difference between column-major and row-major layouts.

We will support row-major for shrink kernel in the next release.

@jcao-ai

jcao-ai commented Nov 22, 2023

@abcdabcd987 @yzh119
I also hit a case where the kernel launch fails when rank == 64 for sgmv shrink usage:

import torch
import punica.ops

bs = 1
h1 = 1024
h2 = 64
num_layers = 32
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [1]

w = [
    torch.randn((num_layers, h2, h1), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]

w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
# punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0)
punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0)

print(y)

Output:

RuntimeError: No suitable kernel. dtype=Half d_out=64

@jcao-ai

jcao-ai commented Nov 22, 2023


NVM, I found this is related to shared memory. PR: #20

@harryhan618
Author

Hi, any updates on why the cutlass group gemm calculates wrong results?

@abcdabcd987
Contributor

Hi, any updates on why the cutlass group gemm calculates wrong results?

I just added a few test cases. 0c7cf81

Cutlass only has this problem for shrink. Since we are deprecating cutlass shrink, we probably won't fix this. Before our custom expand lands, you can use punica.add_lora_sgmv_custom_cutlass() for LoRA.

@harryhan618
Copy link
Author

Hi Lequn, I think I found the bug in cutlass shrink.

Please first see cutlass example 24 (group gemm). The second template parameter of LinearCombination should be 128 / cutlass::sizeof_bits<ElementOutput>::value, which for dtype float16 is 8.
(Although I don't know where this formula comes from.)

In your code, for shrink, you wrote 4; I think this is the bug. For expand, you wrote 8, which is correct.
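
As a quick sanity check on that formula (my own sketch, not from the punica or cutlass code), the alignment works out to 8 elements for float16:

import torch

# 128 / cutlass::sizeof_bits<ElementOutput>::value, computed in Python;
# torch.finfo(...).bits gives the element width in bits (16 for float16).
alignment = 128 // torch.finfo(torch.float16).bits
print(alignment)  # 8 -- matches the value used for expand, not the 4 used for shrink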

By the way, to make the code compile correctly, I had to change the threadblock shape and warp shape to GemmShape<16, 128, 64> and GemmShape<16, 32, 64>.

So I'm also wondering how to choose these shapes, since that's the key difference between shrink and expand. I'm looking forward to your insight!

Thank you!
