Skip to content

Commit

Permalink
update sgmv test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
abcdabcd987 committed Nov 27, 2023
1 parent 51cd092 commit 87cb9f5
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions tests/test_sgmv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch

Expand All @@ -16,7 +17,7 @@ def sgmv_ref_impl(
y: torch.Tensor,
x: torch.Tensor,
w: list[torch.Tensor],
s: torch.IntTensor,
s: torch.Tensor,
layer_idx: int,
):
for i in range(len(w)):
Expand All @@ -26,6 +27,33 @@ def sgmv_ref_impl(
y[s[i] : s[i + 1]] = (yi + xi @ wi).to(y.dtype)


def get_lora_lens(bs: int, popularity: str) -> list[int]:
    """Split a batch of ``bs`` rows into per-LoRA segment lengths.

    The returned lengths always sum to ``bs``.

    Args:
        bs: Total batch size to distribute.
        popularity: Distribution of rows across LoRA models:

            * ``"identical"`` -- a single segment holding the whole batch.
            * ``"distinct"`` -- ``bs`` segments of length 1.
            * ``"uniform"`` -- ``ceil(sqrt(bs))`` segments whose lengths
              differ by at most one, longer segments first.
            * ``"zipf:<alpha>"`` -- geometrically growing segments
              (``floor(alpha ** k)``), with the remainder as one extra
              segment, sorted longest first. Requires ``alpha > 1``.

    Returns:
        A list of segment lengths (plain Python ints).

    Raises:
        KeyError: If ``popularity`` is not one of the supported modes.
    """
    if popularity == "identical":
        return [bs]
    if popularity == "distinct":
        return [1] * bs
    if popularity == "uniform":
        n = int(np.ceil(np.sqrt(bs)))
        if n == 0:
            # Empty batch: no segments. Without this guard ``bs // n``
            # divides by zero.
            return []
        # The remainder is always < n, so a single divmod distributes it:
        # the first ``extra`` segments get one additional row each.
        base, extra = divmod(bs, n)
        return [base + 1] * extra + [base] * (n - extra)
    if popularity.startswith("zipf:"):
        alpha = float(popularity.split(":")[1])
        # alpha < 1 would make the segment size collapse to zero and the
        # loop below would never terminate.
        assert alpha > 1, f"zipf alpha must be > 1, got {alpha}"
        lens = []
        total = 0  # running sum of ``lens``; avoids re-summing every step
        a = 1.0
        # ``a`` is positive, so int(a) == floor(a).
        while total + int(a) < bs:
            size = int(a)
            lens.append(size)
            total += size
            a *= alpha
        # Whatever is left over forms the final segment.
        lens.append(bs - total)
        return sorted(lens, reverse=True)
    raise KeyError(popularity)


def lora_ref_impl(
y: torch.Tensor,
x: torch.Tensor,
Expand Down Expand Up @@ -53,11 +81,12 @@ def lora_ref_impl(
pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")),
],
)
@pytest.mark.parametrize("batch_setup", ["1x7", "7x1", "3x3", "32x1", "1x32"])
@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical"])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 7, 10, 16, 32, 64, 133])
@torch.inference_mode()
def test_sgmv_correctness(dtype_str, h, r, direction, batch_setup):
def test_sgmv_correctness(dtype_str, h, r, direction, popularity, batch_size):
torch.manual_seed(0xABCDABCD987)
num_problems, problem_size = map(int, batch_setup.split("x"))
seqlens = get_lora_lens(batch_size, popularity)
num_layers = 5
dtype = getattr(torch, dtype_str)
device = torch.device("cuda:0")
Expand All @@ -68,16 +97,16 @@ def test_sgmv_correctness(dtype_str, h, r, direction, batch_setup):

w = [
torch.randn((num_layers, h2, h1), dtype=dtype, device=device)
for _ in range(num_problems)
for _ in range(len(seqlens))
]
w_ptr = torch.tensor([t.data_ptr() for t in w], dtype=torch.int64, device=device)
s = torch.cumsum(
torch.tensor([0] + [problem_size] * num_problems, device=device),
torch.tensor([0] + seqlens, device=device),
dim=0,
dtype=torch.int32,
)
x = torch.randn((s[-1], h1), dtype=dtype, device=device)
y = torch.randn((s[-1], h2), dtype=dtype, device=device)
x = torch.randn((int(s[-1]), h1), dtype=dtype, device=device)
y = torch.randn((int(s[-1]), h2), dtype=dtype, device=device)
for layer_idx in range(num_layers):
y_ref = y.clone()
sgmv_ref_impl(y_ref, x, w, s, layer_idx)
Expand Down

0 comments on commit 87cb9f5

Please sign in to comment.